From 2dbab118bac2df99a2224c41a1e8f24a8b5495c5 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sat, 10 Aug 2024 23:29:08 +0900 Subject: [PATCH] feat: Add check-and-pull RPC function --- src/ai/backend/agent/server.py | 41 ++++++ src/ai/backend/common/events.py | 23 ++++ ...67d26e_update_session_and_kernel_status.py | 2 +- src/ai/backend/manager/models/kernel.py | 19 ++- src/ai/backend/manager/models/session.py | 20 +-- src/ai/backend/manager/registry.py | 129 ++++++++++++++++++ 6 files changed, 218 insertions(+), 16 deletions(-) diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 2f5159feb47..ce48ee3692b 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -53,14 +53,18 @@ from ai.backend.common.docker import ImageRef from ai.backend.common.etcd import AsyncEtcd, ConfigScopes from ai.backend.common.events import ( + ImagePullFinishedEvent, + ImagePullStartedEvent, KernelLifecycleEventReason, KernelTerminatedEvent, ) from ai.backend.common.types import ( + AutoPullBehavior, ClusterInfo, CommitStatus, HardwareMetadata, HostPortPair, + ImageConfig, ImageRegistry, KernelCreationConfig, KernelId, @@ -478,6 +482,43 @@ async def sync_kernel_registry( suppress_events=True, ) + @rpc_function + @collect_error + async def check_and_pull( + self, + image_config: Mapping[str, Any], + ) -> dict[str, str]: + """ + Check whether the agent has an image. + Spawn a bgtask that pulls the specified image and return bgtask ID. + """ + img_conf = cast(ImageConfig, image_config) + img_ref = ImageRef.from_image_config(img_conf) + + bgtask_mgr = self.agent.background_task_manager + + async def _pull(reporter: ProgressReporter) -> None: + need_to_pull = await self.agent.check_image( + img_ref, img_conf["digest"], AutoPullBehavior(img_conf["auto_pull"]) + ) + if need_to_pull: + await self.agent.produce_event( + ImagePullStartedEvent( + image=str(img_ref), + ) + ) + await self.agent.pull_image(img_ref, img_conf["registry"]) + await self.agent.produce_event( + ImagePullFinishedEvent( + image=str(img_ref), + ) + ) + + task_id = await bgtask_mgr.start(_pull) + return { + "bgtask_id": str(task_id), + } + @rpc_function @collect_error async def create_kernels( diff --git a/src/ai/backend/common/events.py b/src/ai/backend/common/events.py index 6dc7a7e477d..d41d713047c 100644 --- a/src/ai/backend/common/events.py +++ b/src/ai/backend/common/events.py @@ -210,6 +210,28 @@ def deserialize(cls, value: tuple): ) +@attrs.define(slots=True, frozen=True) +class ImagePullEventArgs: + image: str = attrs.field() + + def serialize(self) -> tuple: + return (self.image,) + + @classmethod + def deserialize(cls, value: tuple): + return cls( + image=value[0], + ) + + +class ImagePullStartedEvent(ImagePullEventArgs, AbstractEvent): + name = "image_pull_started" + + +class ImagePullFinishedEvent(ImagePullEventArgs, AbstractEvent): + name = "image_pull_finished" + + class KernelLifecycleEventReason(enum.StrEnum): AGENT_TERMINATION = "agent-termination" ALREADY_TERMINATED = "already-terminated" @@ -285,6 +307,7 @@ class KernelPullingEvent(KernelCreationEventArgs, AbstractEvent): name = "kernel_pulling" +# TODO: Remove this event @attrs.define(auto_attribs=True, slots=True) class KernelPullProgressEvent(AbstractEvent): name = "kernel_pull_progress" diff --git a/src/ai/backend/manager/models/alembic/versions/6e44ea67d26e_update_session_and_kernel_status.py b/src/ai/backend/manager/models/alembic/versions/6e44ea67d26e_update_session_and_kernel_status.py index bbdc9bda2c1..80647a761af 100644 --- a/src/ai/backend/manager/models/alembic/versions/6e44ea67d26e_update_session_and_kernel_status.py +++ b/src/ai/backend/manager/models/alembic/versions/6e44ea67d26e_update_session_and_kernel_status.py @@ -20,7 +20,7 @@ KERNEL_STATUS_ENUM_NAME = "kernelstatus" SESSION_STATUS_ENUM_NAME = "sessionstatus" -NEW_STATUS_NAME = "READY_TO_CREATE" +NEW_STATUS_NAME = "READY_TO_START" class OldKernelStatus(enum.Enum): diff --git a/src/ai/backend/manager/models/kernel.py b/src/ai/backend/manager/models/kernel.py index dea24ec2d3b..fb5d1318b90 100644 --- a/src/ai/backend/manager/models/kernel.py +++ b/src/ai/backend/manager/models/kernel.py @@ -4,7 +4,7 @@ import enum import logging import uuid -from collections.abc import Mapping +from collections.abc import Container, Mapping from contextlib import asynccontextmanager as actxmgr from datetime import datetime from typing import ( @@ -120,7 +120,7 @@ class KernelStatus(enum.Enum): # --- BUILDING = 20 PULLING = 21 - READY_TO_CREATE = 22 + READY_TO_START = 22 # --- RUNNING = 30 RESTARTING = 31 @@ -221,19 +221,19 @@ def default_hostname(context) -> str: }, KernelStatus.SCHEDULED: { KernelStatus.PULLING, - KernelStatus.READY_TO_CREATE, + KernelStatus.READY_TO_START, KernelStatus.PREPARING, # TODO: Delete this after applying check-and-pull API KernelStatus.CANCELLED, KernelStatus.ERROR, }, KernelStatus.PULLING: { - KernelStatus.READY_TO_CREATE, + KernelStatus.READY_TO_START, KernelStatus.PREPARING, # TODO: Delete this after applying check-and-pull API KernelStatus.RUNNING, # TODO: Delete this after applying check-and-pull API KernelStatus.CANCELLED, KernelStatus.ERROR, }, - KernelStatus.READY_TO_CREATE: { + KernelStatus.READY_TO_START: { KernelStatus.PREPARING, KernelStatus.CANCELLED, KernelStatus.ERROR, @@ -657,6 +657,15 @@ async def get_kernel_to_update_status( raise KernelNotFound(f"Kernel not found (id:{kernel_id})") return kernel_row + @classmethod + async def get_bulk_kernels_to_update_status( + cls, + db_session: SASession, + kernel_ids: Container[KernelId], + ) -> list[KernelRow]: + _stmt = sa.select(KernelRow).where(KernelRow.id.in_(kernel_ids)) + return (await db_session.scalars(_stmt)).all() + def transit_status( self, status: KernelStatus, diff --git a/src/ai/backend/manager/models/session.py b/src/ai/backend/manager/models/session.py index de1d30083d4..cfb03175e79 100644 --- a/src/ai/backend/manager/models/session.py +++ b/src/ai/backend/manager/models/session.py @@ -135,7 +135,7 @@ class SessionStatus(enum.Enum): # manager can set PENDING and SCHEDULED independently # --- PULLING = 7 - READY_TO_CREATE = 9 + READY_TO_START = 9 PREPARING = 10 # --- RUNNING = 30 @@ -150,7 +150,7 @@ class SessionStatus(enum.Enum): FOLLOWING_SESSION_STATUSES = ( # Session statuses that need to wait all kernels belonging to the session - SessionStatus.READY_TO_CREATE, + SessionStatus.READY_TO_START, SessionStatus.RUNNING, SessionStatus.TERMINATED, ) @@ -214,7 +214,7 @@ class SessionStatus(enum.Enum): KernelStatus.PREPARING: SessionStatus.PREPARING, KernelStatus.BUILDING: SessionStatus.PREPARING, KernelStatus.PULLING: SessionStatus.PULLING, - KernelStatus.READY_TO_CREATE: SessionStatus.READY_TO_CREATE, + KernelStatus.READY_TO_START: SessionStatus.READY_TO_START, KernelStatus.RUNNING: SessionStatus.RUNNING, KernelStatus.RESTARTING: SessionStatus.RESTARTING, KernelStatus.RESIZING: SessionStatus.RUNNING, @@ -230,7 +230,7 @@ class SessionStatus(enum.Enum): SessionStatus.SCHEDULED: KernelStatus.SCHEDULED, SessionStatus.PREPARING: KernelStatus.PREPARING, SessionStatus.PULLING: KernelStatus.PULLING, - SessionStatus.READY_TO_CREATE: KernelStatus.READY_TO_CREATE, + SessionStatus.READY_TO_START: KernelStatus.READY_TO_START, SessionStatus.RUNNING: KernelStatus.RUNNING, SessionStatus.RESTARTING: KernelStatus.RESTARTING, SessionStatus.TERMINATING: KernelStatus.TERMINATING, @@ -247,19 +247,19 @@ class SessionStatus(enum.Enum): }, SessionStatus.SCHEDULED: { SessionStatus.PULLING, - SessionStatus.READY_TO_CREATE, + SessionStatus.READY_TO_START, SessionStatus.PREPARING, # TODO: Delete this after applying check-and-pull API SessionStatus.ERROR, SessionStatus.CANCELLED, }, SessionStatus.PULLING: { - SessionStatus.READY_TO_CREATE, + SessionStatus.READY_TO_START, SessionStatus.PREPARING, # TODO: Delete this after applying check-and-pull API SessionStatus.RUNNING, # TODO: Delete this after applying check-and-pull API SessionStatus.ERROR, SessionStatus.CANCELLED, }, - SessionStatus.READY_TO_CREATE: { + SessionStatus.READY_TO_START: { SessionStatus.PREPARING, SessionStatus.ERROR, SessionStatus.CANCELLED, @@ -315,7 +315,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta case ( KernelStatus.PENDING | KernelStatus.SCHEDULED - | KernelStatus.READY_TO_CREATE + | KernelStatus.READY_TO_START | KernelStatus.SUSPENDED | KernelStatus.TERMINATED | KernelStatus.CANCELLED @@ -339,7 +339,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta pass case ( KernelStatus.SCHEDULED - | KernelStatus.READY_TO_CREATE + | KernelStatus.READY_TO_START | KernelStatus.PREPARING | KernelStatus.BUILDING | KernelStatus.PULLING @@ -359,7 +359,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta match k.status: case ( KernelStatus.PENDING - | KernelStatus.READY_TO_CREATE + | KernelStatus.READY_TO_START | KernelStatus.SCHEDULED | KernelStatus.PREPARING | KernelStatus.BUILDING diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 30a2dc51015..5051f676d22 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -60,6 +60,8 @@ DoAgentResourceCheckEvent, DoSyncKernelLogsEvent, DoTerminateSessionEvent, + ImagePullFinishedEvent, + ImagePullStartedEvent, KernelCancelledEvent, KernelCreatingEvent, KernelLifecycleEventReason, @@ -1316,6 +1318,78 @@ async def _post_enqueue() -> None: ) return session_id + async def _check_and_pull_in_one_agent( + self, + agent_alloc_ctx: AgentAllocationContext, + kernel_agent_bindings: Sequence[KernelAgentBinding], + image_configs: Mapping[str, ImageConfig], + ) -> dict[str, uuid.UUID]: + """ + Return {str(ImageRef): bgtask_id} + """ + assert agent_alloc_ctx.agent_id is not None + + result: dict[str, uuid.UUID] = {} + async with self.agent_cache.rpc_context( + agent_alloc_ctx.agent_id, + ) as rpc: + for img, conf in image_configs.items(): + resp = cast(dict[str, str], await rpc.call.check_and_pull(conf)) + bgtask_id = resp["bgtask_id"] + result[img] = uuid.UUID(bgtask_id) + + return result + + async def check_before_start( + self, + scheduled_session: SessionRow, + ) -> None: + kernel_agent_bindings: list[KernelAgentBinding] = [ + KernelAgentBinding( + kernel=k, + agent_alloc_ctx=AgentAllocationContext( + agent_id=k.agent, + agent_addr=k.agent_addr, + scaling_group=scheduled_session.scaling_group, + ), + allocated_host_ports=set(), + ) + for k in scheduled_session.kernels + ] + + # Aggregate image registry information + _image_refs: set[ImageRef] = set([item.kernel.image_ref for item in kernel_agent_bindings]) + auto_pull = cast(str, self.shared_config["docker"]["image"]["auto_pull"]) + async with self.db.begin_readonly_session() as db_session: + configs = await bulk_get_image_configs( + db_session, self.shared_config.etcd, _image_refs, AutoPullBehavior(auto_pull) + ) + img_ref_to_conf_map = {ImageRef.from_image_config(item): item for item in configs} + + def _keyfunc(binding: KernelAgentBinding) -> AgentId: + assert ( + binding.agent_alloc_ctx.agent_id is not None + ), f"No agent assigned to kernel (k:{binding.kernel.id})" + return binding.agent_alloc_ctx.agent_id + + async with aiotools.PersistentTaskGroup() as tg: + for agent_id, group_iterator in itertools.groupby( + sorted(kernel_agent_bindings, key=_keyfunc), + key=_keyfunc, + ): + items: list[KernelAgentBinding] = [*group_iterator] + # Within a group, agent_alloc_ctx are same. + agent_alloc_ctx = items[0].agent_alloc_ctx + _filtered_imgs: set[ImageRef] = {binding.kernel.image_ref for binding in items} + _img_conf_map = { + str(img): conf + for img, conf in img_ref_to_conf_map.items() + if img in _filtered_imgs + } + tg.create_task( + self._check_and_pull_in_one_agent(agent_alloc_ctx, items, _img_conf_map) + ) + async def start_session( self, sched_ctx: SchedulingContext, @@ -3059,6 +3133,47 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None: (str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels ]) + async def mark_image_pull_started( + self, + db_conn: SAConnection, + image: str, + ) -> None: + session_ids: list[SessionId] = [] + + async def _transit(db_session: AsyncSession) -> None: + _stmt = sa.select(KernelRow).where( + (KernelRow.image == image) & (KernelRow.status == KernelStatus.SCHEDULED) + ) + for row in await db_session.scalars(_stmt): + kernel_row = cast(KernelRow, row) + is_pulling = kernel_row.transit_status(KernelStatus.PULLING) + if is_pulling: + session_ids.append(kernel_row.session_id) + + await execute_with_txn_retry(_transit, self.db.begin_session, db_conn) + await self.set_status_updatable_session(session_ids) + + async def mark_image_pull_finished( + self, + db_conn: SAConnection, + image: str, + ) -> None: + session_ids: list[SessionId] = [] + + async def _transit(db_session: AsyncSession) -> None: + _stmt = sa.select(KernelRow).where( + (KernelRow.image == image) + & (KernelRow.status.in_((KernelStatus.SCHEDULED, KernelStatus.PULLING))) + ) + for row in await db_session.scalars(_stmt): + kernel_row = cast(KernelRow, row) + is_ready = kernel_row.transit_status(KernelStatus.READY_TO_START) + if is_ready: + session_ids.append(kernel_row.session_id) + + await execute_with_txn_retry(_transit, self.db.begin_session, db_conn) + await self.set_status_updatable_session(session_ids) + async def mark_kernel_preparing( self, db_conn: SAConnection, @@ -3554,6 +3669,20 @@ async def delete_appproxy_endpoint(self, db_sess: AsyncSession, endpoint: Endpoi pass +async def handle_image_lifecycle( + context: AgentRegistry, + agent_id: AgentId, + event: (ImagePullStartedEvent | ImagePullFinishedEvent), +) -> None: + match event: + case ImagePullStartedEvent(image): + async with context.db.connect() as db_conn: + await context.mark_image_pull_started(db_conn, image) + case ImagePullFinishedEvent(image): + async with context.db.connect() as db_conn: + await context.mark_image_pull_finished(db_conn, image) + + async def handle_kernel_creation_lifecycle( context: AgentRegistry, source: AgentId,