feat: agent resource sync API #2180

Draft · wants to merge 6 commits into base: topic/07-22-feat_schedule_function_returns_kernel-agent_binding
17 changes: 17 additions & 0 deletions src/ai/backend/common/validators.py
@@ -218,6 +218,23 @@ def check_and_return(self, value: Any) -> T_enum:
self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)


class EnumList(t.Trafaret, Generic[T_enum]):
def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None:
self.enum_cls = enum_cls
self.use_name = use_name

def check_and_return(self, value: Any) -> list[T_enum]:
try:
if self.use_name:
return [self.enum_cls[val] for val in value]
else:
return [self.enum_cls(val) for val in value]
except TypeError:
self._failure("cannot parse value into list", value=value)
except (KeyError, ValueError):
self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)


class JSONString(t.Trafaret):
def check_and_return(self, value: Any) -> dict:
try:
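As context for reviewers, here is a minimal sketch of how the new EnumList trafaret behaves, assuming the conventional `validators as tx` import used elsewhere in the codebase; the `Color` enum is a hypothetical stand-in, not part of this PR:

```python
import enum

from ai.backend.common import validators as tx  # where EnumList is added above


class Color(enum.Enum):
    RED = "red"
    BLUE = "blue"


# By value (default): each element is parsed via Color(value).
tx.EnumList(Color).check(["red", "blue"])          # -> [Color.RED, Color.BLUE]

# By member name: each element is parsed via Color[name].
tx.EnumList(Color, use_name=True).check(["RED"])   # -> [Color.RED]

# A non-iterable input fails with "cannot parse value into list";
# an unknown member fails with "value is not a valid member of Color".
```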
46 changes: 45 additions & 1 deletion src/ai/backend/manager/api/session.py
@@ -43,7 +43,11 @@
import trafaret as t
from aiohttp import hdrs, web
from dateutil.tz import tzutc
from pydantic import AliasChoices, BaseModel, Field
from pydantic import (
AliasChoices,
BaseModel,
Field,
)
from redis.asyncio import Redis
from sqlalchemy.orm import noload, selectinload
from sqlalchemy.sql.expression import null, true
@@ -73,6 +77,7 @@
ClusterMode,
ImageRegistry,
KernelId,
KernelStatusCollection,
MountPermission,
MountTypes,
SessionTypes,
@@ -82,6 +87,7 @@

from ..config import DEFAULT_CHUNK_SIZE
from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE
from ..exceptions import MultiAgentError
from ..models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
DEAD_SESSION_STATUSES,
@@ -969,6 +975,43 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe
return web.json_response({}, status=200)


class SyncAgentResourceRequestModel(BaseModel):
agent_id: AgentId = Field(
validation_alias=AliasChoices("agent_id", "agent"),
description="Target agent id to sync resource.",
)


@server_status_required(ALL_ALLOWED)
@auth_required
@pydantic_params_api_handler(SyncAgentResourceRequestModel)
async def sync_agent_resource(
request: web.Request, params: SyncAgentResourceRequestModel
) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
requester_access_key, owner_access_key = await get_access_key_scopes(request)

agent_id = params.agent_id
log.info(
"SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
)

try:
result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
val = result.get(agent_id)
match val:
case KernelStatusCollection():
pass
case MultiAgentError():
return web.Response(status=500)
case _:
pass
return web.Response(status=204)


@server_status_required(ALL_ALLOWED)
@auth_required
@check_api_params(
@@ -2315,6 +2358,7 @@ def create_app(
cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster))
cors.add(app.router.add_route("GET", "/_/match", match_sessions))
cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry))
cors.add(app.router.add_route("POST", "/_/sync-agent-resource", sync_agent_resource))
session_resource = cors.add(app.router.add_resource(r"/{session_name}"))
cors.add(session_resource.add_route("GET", get_info))
cors.add(session_resource.add_route("PATCH", restart))
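A hypothetical client-side invocation of the new endpoint is sketched below. The host, port, and `/session` URL prefix are assumptions about the deployment, and the Backend.AI request-signing headers enforced by `@auth_required` are omitted; a real call would need them.

```python
# Hypothetical call to the new sync-agent-resource endpoint (URL is an assumption;
# auth signature headers required by @auth_required are omitted for brevity).
import asyncio

import aiohttp


async def sync_agent(agent_id: str) -> int:
    async with aiohttp.ClientSession() as sess:
        async with sess.post(
            "http://127.0.0.1:8081/session/_/sync-agent-resource",
            json={"agent": agent_id},  # "agent" or "agent_id", per AliasChoices
        ) as resp:
            # 204: sync completed; 500: the registry reported a MultiAgentError
            return resp.status


asyncio.run(sync_agent("i-example-agent"))
```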
4 changes: 4 additions & 0 deletions src/ai/backend/manager/config.py
@@ -224,6 +224,7 @@
from .api.exceptions import ObjectNotFound, ServerMisconfiguredError
from .models.session import SessionStatus
from .pglock import PgAdvisoryLock
from .types import DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS, AgentResourceSyncTrigger

log = BraceStyleAdapter(logging.getLogger(__spec__.name))

@@ -295,6 +296,9 @@
"agent-selection-resource-priority",
default=["cuda", "rocm", "tpu", "cpu", "mem"],
): t.List(t.String),
t.Key(
"agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS
): tx.EnumList(AgentResourceSyncTrigger),
t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String,
t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt, # default: 16 MiB
tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[
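To illustrate what the new `agent-resource-sync-trigger` key accepts, here is a standalone validation sketch; it assumes the `EnumList` validator and `AgentResourceSyncTrigger` enum added in this PR are importable as shown.

```python
# Minimal validation sketch for the new manager config key (assumes this PR's additions).
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.manager.types import (
    DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS,
    AgentResourceSyncTrigger,
)

schema = t.Dict({
    t.Key(
        "agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS
    ): tx.EnumList(AgentResourceSyncTrigger),
})

schema.check({"agent-resource-sync-trigger": ["after-scheduling", "on-creation-failure"]})
# -> {"agent-resource-sync-trigger": [AgentResourceSyncTrigger.AFTER_SCHEDULING,
#                                     AgentResourceSyncTrigger.ON_CREATION_FAILURE]}

schema.check({})
# -> falls back to the default: [AgentResourceSyncTrigger.ON_CREATION_FAILURE]
```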
23 changes: 21 additions & 2 deletions src/ai/backend/manager/registry.py
@@ -129,7 +129,7 @@
)
from .config import LocalConfig, SharedConfig
from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS
from .exceptions import MultiAgentError, convert_to_status_data
from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data
from .models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES,
@@ -182,7 +182,7 @@
reenter_txn_session,
sql_json_merge,
)
from .types import UserScope
from .types import AgentResourceSyncTrigger, UserScope

if TYPE_CHECKING:
from sqlalchemy.engine.row import Row
@@ -1694,6 +1694,10 @@ async def _create_kernels_in_one_agent(
is_local = image_info["is_local"]
resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"]
auto_pull = image_info["auto_pull"]
agent_resource_sync_trigger = cast(
list[AgentResourceSyncTrigger],
self.local_config["manager"]["agent-resource-sync-trigger"],
)
assert agent_alloc_ctx.agent_id is not None
assert scheduled_session.id is not None

@@ -1790,6 +1794,9 @@ async def _update_kernel() -> None:
ex = e
err_info = convert_to_status_data(ex, self.debug)

def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool:
return err_info["error"]["name"] == "InsufficientResource"

# The agent has already cancelled or issued the destruction lifecycle event
# for this batch of kernels.
for binding in items:
@@ -1821,6 +1828,18 @@ async def _update_failure() -> None:
await db_sess.execute(query)

await execute_with_retry(_update_failure)
if (
AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger
and _is_insufficient_resource_err(err_info)
):
await self.sync_agent_resource(
self.db,
[
binding.agent_alloc_ctx.agent_id
for binding in items
if binding.agent_alloc_ctx.agent_id is not None
],
)
raise

async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair:
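Condensed, the new failure-path behavior amounts to the gate sketched below; this is a simplified stand-alone sketch of the condition, not the PR's exact code path.

```python
# Sketch of the gating added above: sync agent resources only when the operator
# opted into the on-creation-failure trigger and the failure was a resource shortage.
from ai.backend.manager.types import AgentResourceSyncTrigger


def should_sync_on_failure(
    triggers: list[AgentResourceSyncTrigger],
    err_info: dict,  # shape follows convert_to_status_data()'s error status payload
) -> bool:
    return (
        AgentResourceSyncTrigger.ON_CREATION_FAILURE in triggers
        and err_info["error"]["name"] == "InsufficientResource"
    )
```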
24 changes: 19 additions & 5 deletions src/ai/backend/manager/scheduler/dispatcher.py
@@ -20,6 +20,7 @@
Sequence,
Tuple,
Union,
cast,
)

import aiotools
@@ -94,6 +95,7 @@
)
from ..models.utils import ExtendedAsyncSAEngine as SAEngine
from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge
from ..types import AgentResourceSyncTrigger
from .predicates import (
check_concurrency,
check_dependencies,
@@ -265,6 +267,10 @@ async def schedule(
log.debug("schedule(): triggered")
manager_id = self.local_config["manager"]["id"]
redis_key = f"manager.{manager_id}.schedule"
agent_resource_sync_trigger = cast(
list[AgentResourceSyncTrigger],
self.local_config["manager"]["agent-resource-sync-trigger"],
)

def _pipeline(r: Redis) -> RedisPipeline:
pipe = r.pipeline()
@@ -293,20 +299,17 @@ def _pipeline(r: Redis) -> RedisPipeline:
# as its individual steps are composed of many short-lived transactions.
async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60):
async with self.db.begin_readonly_session() as db_sess:
# query = (
# sa.select(ScalingGroupRow)
# .join(ScalingGroupRow.agents.and_(AgentRow.status == AgentStatus.ALIVE))
# )
query = (
sa.select(AgentRow.scaling_group)
.where(AgentRow.status == AgentStatus.ALIVE)
.group_by(AgentRow.scaling_group)
)
result = await db_sess.execute(query)
schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]

for sgroup_name in schedulable_scaling_groups:
try:
await self._schedule_in_sgroup(
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
sgroup_name,
)
@@ -320,6 +323,17 @@ def _pipeline(r: Redis) -> RedisPipeline:
)
except Exception as e:
log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(self.db, selected_agent_ids)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
11 changes: 11 additions & 0 deletions src/ai/backend/manager/types.py
@@ -56,3 +56,14 @@ class MountOptionModel(BaseModel):
MountPermission | None,
Field(validation_alias=AliasChoices("permission", "perm"), default=None),
]


class AgentResourceSyncTrigger(enum.StrEnum):
AFTER_SCHEDULING = "after-scheduling"
BEFORE_KERNEL_CREATION = "before-kernel-creation"
ON_CREATION_FAILURE = "on-creation-failure"


DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [
AgentResourceSyncTrigger.ON_CREATION_FAILURE
]
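For reviewers, an annotated copy of the enum summarizing where each trigger is consulted, based only on the diffs shown above; this is a reviewer's note, not code from the PR.

```python
import enum


class AgentResourceSyncTrigger(enum.StrEnum):
    # dispatcher.schedule(): after a scaling group schedules successfully, sync the
    # agents selected for the new kernel-agent bindings.
    AFTER_SCHEDULING = "after-scheduling"
    # Declared and configurable, but no call site appears in the diffs of this PR yet.
    BEFORE_KERNEL_CREATION = "before-kernel-creation"
    # registry._create_kernels_in_one_agent(): when kernel creation fails with an
    # InsufficientResource error, sync the agents that were targeted.
    ON_CREATION_FAILURE = "on-creation-failure"
```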