Skip to content

Commit

Permalink
feat: Add check-and-pull RPC function
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Sep 2, 2024
1 parent 5833fd0 commit c26d5ee
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 16 deletions.
41 changes: 41 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,18 @@
from ai.backend.common.docker import ImageRef
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.events import (
ImagePullFinishedEvent,
ImagePullStartedEvent,
KernelLifecycleEventReason,
KernelTerminatedEvent,
)
from ai.backend.common.types import (
AutoPullBehavior,
ClusterInfo,
CommitStatus,
HardwareMetadata,
HostPortPair,
ImageConfig,
ImageRegistry,
KernelCreationConfig,
KernelId,
Expand Down Expand Up @@ -478,6 +482,43 @@ async def sync_kernel_registry(
suppress_events=True,
)

@rpc_function
@collect_error
async def check_and_pull(
self,
image_config: Mapping[str, Any],
) -> dict[str, str]:
"""
Check whether the agent has an image.
Spawn a bgtask that pulls the specified image and return bgtask ID.
"""
img_conf = cast(ImageConfig, image_config)
img_ref = ImageRef.from_image_config(img_conf)

bgtask_mgr = self.agent.background_task_manager

async def _pull(reporter: ProgressReporter) -> None:
need_to_pull = await self.agent.check_image(
img_ref, img_conf["digest"], AutoPullBehavior(img_conf["auto_pull"])
)
if need_to_pull:
await self.agent.produce_event(
ImagePullStartedEvent(
image=str(img_ref),
)
)
await self.agent.pull_image(img_ref, img_conf["registry"])
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
)
)

task_id = await bgtask_mgr.start(_pull)
return {
"bgtask_id": str(task_id),
}

@rpc_function
@collect_error
async def create_kernels(
Expand Down
23 changes: 23 additions & 0 deletions src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,28 @@ def deserialize(cls, value: tuple):
)


@attrs.define(slots=True, frozen=True)
class ImagePullEventArgs:
image: str = attrs.field()

def serialize(self) -> tuple:
return (self.image,)

@classmethod
def deserialize(cls, value: tuple):
return cls(
image=value[0],
)


class ImagePullStartedEvent(ImagePullEventArgs, AbstractEvent):
name = "image_pull_started"


class ImagePullFinishedEvent(ImagePullEventArgs, AbstractEvent):
name = "image_pull_finished"


class KernelLifecycleEventReason(enum.StrEnum):
AGENT_TERMINATION = "agent-termination"
ALREADY_TERMINATED = "already-terminated"
Expand Down Expand Up @@ -285,6 +307,7 @@ class KernelPullingEvent(KernelCreationEventArgs, AbstractEvent):
name = "kernel_pulling"


# TODO: Remove this event
@attrs.define(auto_attribs=True, slots=True)
class KernelPullProgressEvent(AbstractEvent):
name = "kernel_pull_progress"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
KERNEL_STATUS_ENUM_NAME = "kernelstatus"
SESSION_STATUS_ENUM_NAME = "sessionstatus"

NEW_STATUS_NAME = "READY_TO_CREATE"
NEW_STATUS_NAME = "READY_TO_START"


class OldKernelStatus(enum.Enum):
Expand Down
19 changes: 14 additions & 5 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import enum
import logging
import uuid
from collections.abc import Mapping
from collections.abc import Container, Mapping
from contextlib import asynccontextmanager as actxmgr
from datetime import datetime
from typing import (
Expand Down Expand Up @@ -120,7 +120,7 @@ class KernelStatus(enum.Enum):
# ---
BUILDING = 20
PULLING = 21
READY_TO_CREATE = 22
READY_TO_START = 22
# ---
RUNNING = 30
RESTARTING = 31
Expand Down Expand Up @@ -221,19 +221,19 @@ def default_hostname(context) -> str:
},
KernelStatus.SCHEDULED: {
KernelStatus.PULLING,
KernelStatus.READY_TO_CREATE,
KernelStatus.READY_TO_START,
KernelStatus.PREPARING, # TODO: Delete this after applying check-and-pull API
KernelStatus.CANCELLED,
KernelStatus.ERROR,
},
KernelStatus.PULLING: {
KernelStatus.READY_TO_CREATE,
KernelStatus.READY_TO_START,
KernelStatus.PREPARING, # TODO: Delete this after applying check-and-pull API
KernelStatus.RUNNING, # TODO: Delete this after applying check-and-pull API
KernelStatus.CANCELLED,
KernelStatus.ERROR,
},
KernelStatus.READY_TO_CREATE: {
KernelStatus.READY_TO_START: {
KernelStatus.PREPARING,
KernelStatus.CANCELLED,
KernelStatus.ERROR,
Expand Down Expand Up @@ -657,6 +657,15 @@ async def get_kernel_to_update_status(
raise KernelNotFound(f"Kernel not found (id:{kernel_id})")
return kernel_row

@classmethod
async def get_bulk_kernels_to_update_status(
cls,
db_session: SASession,
kernel_ids: Container[KernelId],
) -> list[KernelRow]:
_stmt = sa.select(KernelRow).where(KernelRow.id.in_(kernel_ids))
return (await db_session.scalars(_stmt)).all()

def transit_status(
self,
status: KernelStatus,
Expand Down
20 changes: 10 additions & 10 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class SessionStatus(enum.Enum):
# manager can set PENDING and SCHEDULED independently
# ---
PULLING = 7
READY_TO_CREATE = 9
READY_TO_START = 9
PREPARING = 10
# ---
RUNNING = 30
Expand All @@ -150,7 +150,7 @@ class SessionStatus(enum.Enum):

FOLLOWING_SESSION_STATUSES = (
# Session statuses that need to wait all kernels belonging to the session
SessionStatus.READY_TO_CREATE,
SessionStatus.READY_TO_START,
SessionStatus.RUNNING,
SessionStatus.TERMINATED,
)
Expand Down Expand Up @@ -214,7 +214,7 @@ class SessionStatus(enum.Enum):
KernelStatus.PREPARING: SessionStatus.PREPARING,
KernelStatus.BUILDING: SessionStatus.PREPARING,
KernelStatus.PULLING: SessionStatus.PULLING,
KernelStatus.READY_TO_CREATE: SessionStatus.READY_TO_CREATE,
KernelStatus.READY_TO_START: SessionStatus.READY_TO_START,
KernelStatus.RUNNING: SessionStatus.RUNNING,
KernelStatus.RESTARTING: SessionStatus.RESTARTING,
KernelStatus.RESIZING: SessionStatus.RUNNING,
Expand All @@ -230,7 +230,7 @@ class SessionStatus(enum.Enum):
SessionStatus.SCHEDULED: KernelStatus.SCHEDULED,
SessionStatus.PREPARING: KernelStatus.PREPARING,
SessionStatus.PULLING: KernelStatus.PULLING,
SessionStatus.READY_TO_CREATE: KernelStatus.READY_TO_CREATE,
SessionStatus.READY_TO_START: KernelStatus.READY_TO_START,
SessionStatus.RUNNING: KernelStatus.RUNNING,
SessionStatus.RESTARTING: KernelStatus.RESTARTING,
SessionStatus.TERMINATING: KernelStatus.TERMINATING,
Expand All @@ -247,19 +247,19 @@ class SessionStatus(enum.Enum):
},
SessionStatus.SCHEDULED: {
SessionStatus.PULLING,
SessionStatus.READY_TO_CREATE,
SessionStatus.READY_TO_START,
SessionStatus.PREPARING, # TODO: Delete this after applying check-and-pull API
SessionStatus.ERROR,
SessionStatus.CANCELLED,
},
SessionStatus.PULLING: {
SessionStatus.READY_TO_CREATE,
SessionStatus.READY_TO_START,
SessionStatus.PREPARING, # TODO: Delete this after applying check-and-pull API
SessionStatus.RUNNING, # TODO: Delete this after applying check-and-pull API
SessionStatus.ERROR,
SessionStatus.CANCELLED,
},
SessionStatus.READY_TO_CREATE: {
SessionStatus.READY_TO_START: {
SessionStatus.PREPARING,
SessionStatus.ERROR,
SessionStatus.CANCELLED,
Expand Down Expand Up @@ -315,7 +315,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta
case (
KernelStatus.PENDING
| KernelStatus.SCHEDULED
| KernelStatus.READY_TO_CREATE
| KernelStatus.READY_TO_START
| KernelStatus.SUSPENDED
| KernelStatus.TERMINATED
| KernelStatus.CANCELLED
Expand All @@ -339,7 +339,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta
pass
case (
KernelStatus.SCHEDULED
| KernelStatus.READY_TO_CREATE
| KernelStatus.READY_TO_START
| KernelStatus.PREPARING
| KernelStatus.BUILDING
| KernelStatus.PULLING
Expand All @@ -359,7 +359,7 @@ def determine_session_status(sibling_kernels: Sequence[KernelRow]) -> SessionSta
match k.status:
case (
KernelStatus.PENDING
| KernelStatus.READY_TO_CREATE
| KernelStatus.READY_TO_START
| KernelStatus.SCHEDULED
| KernelStatus.PREPARING
| KernelStatus.BUILDING
Expand Down
Loading

0 comments on commit c26d5ee

Please sign in to comment.