feat: Implement check-and-pull mechanism for session creation lifecycle #2721

1 change: 1 addition & 0 deletions changes/2721.feature.md
@@ -0,0 +1 @@
Check if agent has the required image before creating compute kernels
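
In rough terms, the new lifecycle lets the manager ask each agent whether it already holds a session's image and, if not, have the agent pull it in a background task before compute kernels are created. A minimal caller-side sketch of that flow (the `agent_client` object and its exact RPC signature are assumptions for illustration; only the `check_and_pull` name and the per-image bgtask-ID return value come from this PR):

from typing import Mapping


async def ensure_images(agent_client, image_configs: Mapping[str, dict]) -> dict[str, str]:
    # Ask the agent to verify its local images and spawn background pulls as needed.
    # The RPC is expected to return {image_canonical: bgtask_id}.
    bgtask_ids = await agent_client.check_and_pull(image_configs)
    # The manager then waits for ImagePullStarted/Finished/Failed events
    # before moving kernels out of the pulling state (or cancelling them).
    return bgtask_ids
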
64 changes: 58 additions & 6 deletions src/ai/backend/agent/server.py
@@ -12,6 +12,7 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from datetime import datetime, timezone
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from pathlib import Path
@@ -53,6 +54,7 @@
from ai.backend.common.docker import ImageRef
from ai.backend.common.etcd import AsyncEtcd, ConfigScopes
from ai.backend.common.events import (
ImagePullFailedEvent,
ImagePullFinishedEvent,
ImagePullStartedEvent,
KernelLifecycleEventReason,
@@ -492,6 +494,18 @@ async def check_and_pull(
Check whether the agent has an image.
Spawn a bgtask that pulls the specified image and return bgtask ID.
"""
log.info(
"rpc::check_and_pull(images:{0})",
[
{
"name": conf["canonical"],
"project": conf["project"],
"registry": conf["registry"]["name"],
}
for conf in image_configs.values()
],
)

bgtask_mgr = self.agent.background_task_manager

async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
@@ -500,18 +514,56 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
img_ref, img_conf["digest"], AutoPullBehavior(img_conf["auto_pull"])
)
if need_to_pull:
log.info(f"rpc::check_and_pull() start pulling {str(img_ref)}")
await self.agent.produce_event(
- ImagePullStartedEvent(image=str(img_ref), agent_id=self.agent.id)
ImagePullStartedEvent(
image=str(img_ref),
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
)
)
image_pull_timeout = cast(
Optional[float], self.local_config["agent"]["api"]["pull-timeout"]
)
- await self.agent.pull_image(
- img_ref, img_conf["registry"], timeout=image_pull_timeout
try:
await self.agent.pull_image(
img_ref, img_conf["registry"], timeout=image_pull_timeout
)
except asyncio.TimeoutError:
log.exception(f"Image pull timeout after {image_pull_timeout} sec")
await self.agent.produce_event(
ImagePullFailedEvent(
image=str(img_ref),
agent_id=self.agent.id,
msg=f"timeout (s:{image_pull_timeout})",
)
)
except Exception as e:
log.exception(f"Image pull failed (e:{repr(e)})")
await self.agent.produce_event(
ImagePullFailedEvent(
image=str(img_ref),
agent_id=self.agent.id,
msg=repr(e),
)
)
else:
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
)
)
else:
await self.agent.produce_event(
ImagePullFinishedEvent(
image=str(img_ref),
agent_id=self.agent.id,
timestamp=datetime.now(timezone.utc).timestamp(),
msg="Image already exists",
)
)
- await self.agent.produce_event(
- ImagePullFinishedEvent(image=str(img_ref), agent_id=self.agent.id)
- )

ret: dict[str, str] = {}
for img, img_conf in image_configs.items():
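
For context beyond this file, here is a hedged sketch of how a manager-side consumer might react to the pull events emitted above; the handler names and the `ctx.registry` helpers are assumptions, only the event classes and their fields come from this PR:

from ai.backend.common.events import ImagePullFailedEvent, ImagePullFinishedEvent


async def handle_image_pull_finished(ctx, source, event: ImagePullFinishedEvent) -> None:
    # A finished pull (or "Image already exists") means kernels waiting on this
    # image on this agent can advance toward session startup.
    await ctx.registry.mark_image_ready(event.agent_id, event.image)  # assumed helper


async def handle_image_pull_failed(ctx, source, event: ImagePullFailedEvent) -> None:
    # A failed or timed-out pull should surface event.msg on the affected kernels.
    await ctx.registry.mark_image_pull_failed(event.agent_id, event.image, event.msg)  # assumed helper
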
51 changes: 43 additions & 8 deletions src/ai/backend/common/events.py
@@ -93,8 +93,12 @@ class DoScheduleEvent(EmptyEventArgs, AbstractEvent):
name = "do_schedule"


- class DoPrepareEvent(EmptyEventArgs, AbstractEvent):
- name = "do_prepare"
class DoCheckPrecondEvent(EmptyEventArgs, AbstractEvent):
name = "do_check_precond"


class DoStartSessionEvent(EmptyEventArgs, AbstractEvent):
name = "do_start_session"


class DoScaleEvent(EmptyEventArgs, AbstractEvent):
@@ -211,27 +215,54 @@ def deserialize(cls, value: tuple):


@attrs.define(slots=True, frozen=True)
- class ImagePullEventArgs:
class ImagePullStartedEvent(AbstractEvent):
name = "image_pull_started"

image: str = attrs.field()
agent_id: AgentId = attrs.field()
timestamp: float = attrs.field()

def serialize(self) -> tuple:
- return (self.image, str(self.agent_id))
return (
self.image,
str(self.agent_id),
self.timestamp,
)

@classmethod
def deserialize(cls, value: tuple):
return cls(
image=value[0],
agent_id=AgentId(value[1]),
timestamp=value[2],
)


- class ImagePullStartedEvent(ImagePullEventArgs, AbstractEvent):
- name = "image_pull_started"
@attrs.define(slots=True, frozen=True)
class ImagePullFinishedEvent(AbstractEvent):
name = "image_pull_finished"

image: str = attrs.field()
agent_id: AgentId = attrs.field()
timestamp: float = attrs.field()
msg: Optional[str] = attrs.field(default=None)

- class ImagePullFinishedEvent(ImagePullEventArgs, AbstractEvent):
- name = "image_pull_finished"
def serialize(self) -> tuple:
return (
self.image,
str(self.agent_id),
self.timestamp,
self.msg,
)

@classmethod
def deserialize(cls, value: tuple):
return cls(
image=value[0],
agent_id=AgentId(value[1]),
timestamp=value[2],
msg=value[3],
)


@attrs.define(slots=True, frozen=True)
@@ -459,6 +490,10 @@ class SessionScheduledEvent(SessionCreationEventArgs, AbstractEvent):
name = "session_scheduled"


class SessionCheckingPrecondEvent(SessionCreationEventArgs, AbstractEvent):
name = "session_checking_precondition"


class SessionPreparingEvent(SessionCreationEventArgs, AbstractEvent):
name = "session_preparing"

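
A quick round-trip illustration of the reworked event payloads (not part of the diff; the image reference is made up and `AgentId` is assumed to be importable from `ai.backend.common.types`):

import time

from ai.backend.common.events import ImagePullFinishedEvent
from ai.backend.common.types import AgentId

ev = ImagePullFinishedEvent(
    image="cr.backend.ai/stable/python:3.9-ubuntu20.04",  # illustrative image ref
    agent_id=AgentId("i-agent01"),
    timestamp=time.time(),
    msg="Image already exists",
)
assert ImagePullFinishedEvent.deserialize(ev.serialize()) == ev
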
3 changes: 2 additions & 1 deletion src/ai/backend/manager/api/__init__.py
@@ -10,5 +10,6 @@ class ManagerStatus(enum.StrEnum):

class SchedulerEvent(enum.StrEnum):
SCHEDULE = "schedule"
- PREPARE = "prepare"
CHECK_PRECOND = "check_precondition"
START_SESSION = "start_session"
SCALE_SERVICES = "scale_services"
13 changes: 10 additions & 3 deletions src/ai/backend/manager/api/manager.py
@@ -20,7 +20,12 @@

from ai.backend.common import redis_helper
from ai.backend.common import validators as tx
- from ai.backend.common.events import DoPrepareEvent, DoScaleEvent, DoScheduleEvent
from ai.backend.common.events import (
DoCheckPrecondEvent,
DoScaleEvent,
DoScheduleEvent,
DoStartSessionEvent,
)
from ai.backend.common.types import PromMetric, PromMetricGroup, PromMetricPrimitive
from ai.backend.logging import BraceStyleAdapter

@@ -276,8 +281,10 @@ async def scheduler_trigger(request: web.Request, params: Any) -> web.Response:
match params["event"]:
case SchedulerEvent.SCHEDULE:
await root_ctx.event_producer.produce_event(DoScheduleEvent())
- case SchedulerEvent.PREPARE:
- await root_ctx.event_producer.produce_event(DoPrepareEvent())
case SchedulerEvent.CHECK_PRECOND:
await root_ctx.event_producer.produce_event(DoCheckPrecondEvent())
case SchedulerEvent.START_SESSION:
await root_ctx.event_producer.produce_event(DoStartSessionEvent())
case SchedulerEvent.SCALE_SERVICES:
await root_ctx.event_producer.produce_event(DoScaleEvent())
return web.Response(status=204)
7 changes: 5 additions & 2 deletions src/ai/backend/manager/defs.py
@@ -75,9 +75,12 @@
class LockID(enum.IntEnum):
LOCKID_TEST = 42
LOCKID_SCHEDULE = 91
- LOCKID_PREPARE = 92
LOCKID_CHECK_PRECOND = 92
LOCKID_PREPARE = 93
LOCKID_SCHEDULE_TIMER = 191
- LOCKID_PREPARE_TIMER = 192
LOCKID_CHECK_PRECOND_TIMER = 192
LOCKID_PREPARE_TIMER = 193
LOCKID_START_TIMER = 198
LOCKID_SCALE_TIMER = 193
LOCKID_LOG_CLEANUP_TIMER = 195
LOCKID_IDLE_CHECK_TIMER = 196
33 changes: 32 additions & 1 deletion src/ai/backend/manager/models/image.py
@@ -115,6 +115,26 @@ class ImageLoadFilter(enum.StrEnum):
"""Include every customized images filed at the system. Effective only for superadmin. CUSTOMIZED and CUSTOMIZED_GLOBAL are mutually exclusive."""


class RelationLoadingOption(enum.StrEnum):
ALIASES = enum.auto()
ENDPOINTS = enum.auto()
REGISTRY = enum.auto()


def _apply_loading_option(
query_stmt: sa.sql.Select, options: Iterable[RelationLoadingOption]
) -> sa.sql.Select:
for op in options:
match op:
case RelationLoadingOption.ALIASES:
query_stmt = query_stmt.options(selectinload(ImageRow.aliases))
case RelationLoadingOption.REGISTRY:
query_stmt = query_stmt.options(joinedload(ImageRow.registry_row))
case RelationLoadingOption.ENDPOINTS:
query_stmt = query_stmt.options(selectinload(ImageRow.endpoints))
return query_stmt


async def rescan_images(
db: ExtendedAsyncSAEngine,
registry_or_image: str | None = None,
@@ -277,6 +297,8 @@ async def from_alias(
session: AsyncSession,
alias: str,
load_aliases=False,
*,
loading_options: Iterable[RelationLoadingOption] = tuple(),
Review comment on lines 299 to +301: write an issue to refactor this bool flag.

) -> ImageRow:
query = (
sa.select(ImageRow)
Expand All @@ -285,6 +307,7 @@ async def from_alias(
)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
query = _apply_loading_option(query, loading_options)
result = await session.scalar(query)
if result is not None:
return result
@@ -297,6 +320,8 @@ async def from_image_identifier(
session: AsyncSession,
identifier: ImageIdentifier,
load_aliases: bool = True,
*,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
query = sa.select(ImageRow).where(
(ImageRow.name == identifier.canonical)
@@ -305,6 +330,7 @@

if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
query = _apply_loading_option(query, loading_options)

result = await session.execute(query)
candidates: List[ImageRow] = result.scalars().all()
@@ -322,6 +348,7 @@ async def from_image_ref(
*,
strict_arch: bool = False,
load_aliases: bool = False,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
"""
Loads an image row that corresponds to the given ImageRef object.
@@ -333,6 +360,7 @@
query = sa.select(ImageRow).where(ImageRow.name == ref.canonical)
if load_aliases:
query = query.options(selectinload(ImageRow.aliases))
query = _apply_loading_option(query, loading_options)

result = await session.execute(query)
candidates: List[ImageRow] = result.scalars().all()
@@ -354,6 +382,7 @@ async def resolve(
*,
strict_arch: bool = False,
load_aliases: bool = True,
loading_options: Iterable[RelationLoadingOption] = tuple(),
) -> ImageRow:
"""
Resolves a matching row in the image table from image references and/or aliases.
@@ -401,7 +430,9 @@ async def resolve(
resolver_func = cls.from_image_identifier
searched_refs.append(f"identifier:{reference!r}")
try:
- if row := await resolver_func(session, reference, load_aliases=load_aliases):
if row := await resolver_func(
session, reference, load_aliases=load_aliases, loading_options=loading_options
):
return row
except UnknownImageReference:
continue
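
A sketch of how the new `loading_options` parameter might be used at a call site (illustrative; the import locations and the `ImageIdentifier` constructor arguments are assumptions):

from ai.backend.manager.models.image import ImageIdentifier, ImageRow, RelationLoadingOption


async def resolve_with_registry(db_session, canonical: str, architecture: str) -> ImageRow:
    # Resolve the image and eagerly load its registry relation in the same query,
    # e.g. to build the ImageConfig passed to the agent's check_and_pull RPC.
    return await ImageRow.resolve(
        db_session,
        [ImageIdentifier(canonical, architecture)],
        loading_options=[RelationLoadingOption.REGISTRY],
    )
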
11 changes: 4 additions & 7 deletions src/ai/backend/manager/models/kernel.py
@@ -672,13 +672,10 @@ def set_status(
self.terminated_at = now
self.status_changed = now
self.status = status
- self.status_history = sql_json_merge(
- KernelRow.status_history,
- (),
- {
- status.name: now.isoformat(),
- },
- )
self.status_history = {
**self.status_history,
status.name: now.isoformat(),
}
if status_info is not None:
self.status_info = status_info
if status_data is not None:
19 changes: 18 additions & 1 deletion src/ai/backend/manager/models/session.py
@@ -33,7 +33,7 @@
from sqlalchemy.dialects import postgresql as pgsql
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession
- from sqlalchemy.orm import load_only, noload, relationship, selectinload
from sqlalchemy.orm import joinedload, load_only, noload, relationship, selectinload

from ai.backend.common import redis_helper
from ai.backend.common.events import (
@@ -87,6 +87,7 @@
batch_result_in_session,
)
from .group import GroupRow
from .image import ImageRow
from .kernel import ComputeContainer, KernelRow, KernelStatus
from .minilang import ArrayFieldItem, JSONFieldItem
from .minilang.ordering import ColumnMapType, QueryOrderParser
@@ -858,6 +859,22 @@ async def get_session_id_by_kernel(
async with db.begin_readonly_session() as db_session:
return await db_session.scalar(query)

@classmethod
async def get_sessions_by_status(
cls,
db_session: SASession,
status: SessionStatus,
*,
load_kernel_image: bool = False,
) -> list[SessionRow]:
load_options = selectinload(SessionRow.kernels)
if load_kernel_image:
load_options = load_options.options(
joinedload(KernelRow.image_row).options(joinedload(ImageRow.registry_row))
)
stmt = sa.select(SessionRow).where(SessionRow.status == status).options(load_options)
return (await db_session.scalars(stmt)).all()

@classmethod
async def get_session_to_determine_status(
cls, db_session: SASession, session_id: SessionId
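
A sketch of the intended call site for the new helper (assumed; the scheduler-side code is not part of this excerpt), fetching sessions in a given status with each kernel's image and registry rows preloaded:

from ai.backend.manager.models.session import SessionRow, SessionStatus


async def collect_pulling_sessions(db) -> list[SessionRow]:
    # Assumes a PULLING-like status exists in SessionStatus; load_kernel_image=True
    # avoids per-kernel lazy loads when the scheduler inspects kernel images.
    async with db.begin_readonly_session() as db_session:
        return await SessionRow.get_sessions_by_status(
            db_session,
            SessionStatus.PULLING,
            load_kernel_image=True,
        )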