From bbbd88eae56faaa17fcbb36456f15a13a94cf9d6 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Fri, 12 Jul 2024 18:00:59 +0900 Subject: [PATCH] dont pass db cxn to sync_agent_resource() and fix wrong types --- src/ai/backend/agent/server.py | 30 ++++----- src/ai/backend/common/types.py | 2 +- src/ai/backend/manager/api/session.py | 23 +++++-- src/ai/backend/manager/registry.py | 67 ++++++++++++------- .../backend/manager/scheduler/dispatcher.py | 56 +++++++--------- 5 files changed, 99 insertions(+), 79 deletions(-) diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 39f631f50ea..91007cd4e08 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -12,7 +12,6 @@ import signal import sys from collections import OrderedDict, defaultdict -from collections.abc import Collection from ipaddress import _BaseAddress as BaseIPAddress from ipaddress import ip_network from pathlib import Path @@ -83,7 +82,7 @@ ) from .exception import ResourceError from .monitor import AgentErrorPluginContext, AgentStatsPluginContext -from .types import AgentBackend, LifecycleEvent, VolumeInfo +from .types import AgentBackend, KernelLifecycleStatus, LifecycleEvent, VolumeInfo from .utils import get_arch_name, get_subnet_ip if TYPE_CHECKING: @@ -485,10 +484,10 @@ async def sync_kernel_registry( @collect_error async def sync_and_get_kernels( self, - preparing_kernels: Collection[str], - pulling_kernels: Collection[str], - running_kernels: Collection[str], - terminating_kernels: Collection[str], + preparing_kernels: Iterable[UUID], + pulling_kernels: Iterable[UUID], + running_kernels: Iterable[UUID], + terminating_kernels: Iterable[UUID], ) -> dict[str, Any]: """ Sync kernel_registry and containers to truth data @@ -501,14 +500,14 @@ async def sync_and_get_kernels( async with self.agent.registry_lock: actual_existing_kernels = [ kid - for kid in self.agent.kernel_registry - if kid not in self.agent.terminating_kernels + for kid, obj in self.agent.kernel_registry.items() + if obj.state == KernelLifecycleStatus.RUNNING ] for raw_kernel_id in running_kernels: - kernel_id = KernelId(UUID(raw_kernel_id)) + kernel_id = KernelId(raw_kernel_id) if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: - if kernel_id in self.agent.terminating_kernels: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: actual_terminating_kernels.append(( kernel_id, str( @@ -523,9 +522,9 @@ async def sync_and_get_kernels( )) for raw_kernel_id in terminating_kernels: - kernel_id = KernelId(UUID(raw_kernel_id)) + kernel_id = KernelId(raw_kernel_id) if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: - if kernel_id not in self.agent.terminating_kernels: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: await self.agent.inject_container_lifecycle_event( kernel_id, kernel_obj.session_id, @@ -542,7 +541,7 @@ async def sync_and_get_kernels( for kernel_id, kernel_obj in self.agent.kernel_registry.items(): if kernel_id in terminating_kernels: - if kernel_id not in self.agent.terminating_kernels: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: await self.agent.inject_container_lifecycle_event( kernel_id, kernel_obj.session_id, @@ -554,16 +553,15 @@ async def sync_and_get_kernels( elif kernel_id in running_kernels: pass elif kernel_id in preparing_kernels: - # kernel_registry may not have `preparing` state kernels. pass elif kernel_id in pulling_kernels: # kernel_registry does not have `pulling` state kernels. - # Let's just skip it. + # Let's skip it. pass else: # This kernel is not alive according to the truth data. # The kernel should be destroyed or cleaned - if kernel_id in self.agent.terminating_kernels: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: await self.agent.inject_container_lifecycle_event( kernel_id, kernel_obj.session_id, diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 183c16ad2fa..6b74d9310eb 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1301,7 +1301,7 @@ def as_trafaret(cls) -> t.Trafaret: from . import validators as tx return t.Dict({ - t.Key("actual_existing_kernels"): tx.ToList(t.String), + t.Key("actual_existing_kernels"): tx.ToList(tx.UUID), t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)), t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)), }) diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 184d4db7fb2..6738e435e67 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -75,6 +75,7 @@ from ai.backend.common.types import ( AccessKey, AgentId, + AgentKernelRegistryByStatus, ClusterMode, ImageRegistry, KernelId, @@ -86,6 +87,7 @@ from ..config import DEFAULT_CHUNK_SIZE from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE +from ..exceptions import MultiAgentError from ..models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, DEAD_SESSION_STATUSES, @@ -993,13 +995,20 @@ async def sync_agent_resource( "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id ) - async with root_ctx.db.begin() as db_conn: - try: - await root_ctx.registry.sync_agent_resource(db_conn, [agent_id]) - except BackendError: - log.exception("SYNC_AGENT_RESOURCE: exception") - raise - return web.Response(status=204) + try: + result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id]) + except BackendError: + log.exception("SYNC_AGENT_RESOURCE: exception") + raise + val = result.get(agent_id) + match val: + case AgentKernelRegistryByStatus(): + pass + case MultiAgentError(): + return web.Response(status=500) + case _: + pass + return web.Response(status=204) @server_status_required(ALL_ALLOWED) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 50e6b486f45..047dd95fb27 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -12,7 +12,7 @@ import uuid import zlib from collections import defaultdict -from collections.abc import Collection +from collections.abc import Collection, Iterable from datetime import datetime from decimal import Decimal from io import BytesIO @@ -47,7 +47,6 @@ from redis.asyncio import Redis from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import AsyncSession as SASession from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria from sqlalchemy.orm.exc import NoResultFound from yarl import URL @@ -131,7 +130,7 @@ ) from .config import LocalConfig, SharedConfig from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS -from .exceptions import MultiAgentError, convert_to_status_data +from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data from .models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES, @@ -1796,6 +1795,9 @@ async def _update_kernel() -> None: ex = e err_info = convert_to_status_data(ex, self.debug) + def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool: + return err_info["error"]["name"] == "InsufficientResource" + # The agent has already cancelled or issued the destruction lifecycle event # for this batch of kernels. for binding in items: @@ -1827,6 +1829,18 @@ async def _update_failure() -> None: await db_sess.execute(query) await execute_with_retry(_update_failure) + if ( + AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger + and _is_insufficient_resource_err(err_info) + ): + await self.sync_agent_resource( + self.db, + [ + binding.agent_alloc_ctx.agent_id + for binding in items + if binding.agent_alloc_ctx.agent_id is not None + ], + ) raise async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair: @@ -3195,10 +3209,10 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None: async def _sync_agent_resource_and_get_kerenels( self, agent_id: AgentId, - preparing_kernels: Collection[KernelId], - pulling_kernels: Collection[KernelId], - running_kernels: Collection[KernelId], - terminating_kernels: Collection[KernelId], + preparing_kernels: Iterable[KernelId], + pulling_kernels: Iterable[KernelId], + running_kernels: Iterable[KernelId], + terminating_kernels: Iterable[KernelId], ) -> AgentKernelRegistryByStatus: async with self.agent_cache.rpc_context(agent_id) as rpc: resp: dict[str, Any] = await rpc.call.sync_and_get_kernels( @@ -3211,7 +3225,7 @@ async def _sync_agent_resource_and_get_kerenels( async def sync_agent_resource( self, - db_connection: SAConnection, + db: ExtendedAsyncSAEngine, agent_ids: Collection[AgentId], ) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]: result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {} @@ -3232,7 +3246,7 @@ async def sync_agent_resource( ).options(load_only(KernelRow.id, KernelRow.status)) ) ) - async with SASession(bind=db_connection) as db_session: + async with db.begin_readonly_session() as db_session: for _agent_row in await db_session.scalars(stmt): agent_row = cast(AgentRow, _agent_row) preparing_kernels: list[KernelId] = [] @@ -3258,24 +3272,27 @@ async def sync_agent_resource( "running_kernels": running_kernels, "terminating_kernels": terminating_kernels, } - tasks = [] - for agent_id in agent_ids: - tasks.append( - self._sync_agent_resource_and_get_kerenels( - agent_id, - agent_kernel_by_status[agent_id]["preparing_kernels"], - agent_kernel_by_status[agent_id]["pulling_kernels"], - agent_kernel_by_status[agent_id]["running_kernels"], - agent_kernel_by_status[agent_id]["terminating_kernels"], + aid_task_list: list[tuple[AgentId, asyncio.Task]] = [] + async with aiotools.PersistentTaskGroup() as tg: + for agent_id in agent_ids: + task = tg.create_task( + self._sync_agent_resource_and_get_kerenels( + agent_id, + agent_kernel_by_status[agent_id]["preparing_kernels"], + agent_kernel_by_status[agent_id]["pulling_kernels"], + agent_kernel_by_status[agent_id]["running_kernels"], + agent_kernel_by_status[agent_id]["terminating_kernels"], + ) ) - ) - responses = await asyncio.gather(*tasks, return_exceptions=True) - for aid, resp in zip(agent_ids, responses): + aid_task_list.append((agent_id, task)) + for aid, task in aid_task_list: agent_errors = [] - if isinstance(resp, aiotools.TaskGroupError): - agent_errors.extend(resp.__errors__) - elif isinstance(result, Exception): - agent_errors.append(resp) + try: + resp = await task + except aiotools.TaskGroupError as e: + agent_errors.extend(e.__errors__) + except Exception as e: + agent_errors.append(e) if agent_errors: result[aid] = MultiAgentError( "agent(s) raise errors during kernel resource sync", diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 380f60085e9..b1677f8d700 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -306,37 +306,33 @@ def _pipeline(r: Redis) -> RedisPipeline: result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - async with self.db.begin() as db_conn: - for sgroup_name in schedulable_scaling_groups: - try: - kernel_agent_bindings = await self._schedule_in_sgroup( - sched_ctx, + for sgroup_name in schedulable_scaling_groups: + try: + kernel_agent_bindings = await self._schedule_in_sgroup( + sched_ctx, + sgroup_name, + ) + await redis_helper.execute( + self.redis_live, + lambda r: r.hset( + redis_key, + "resource_group", sgroup_name, - ) - await redis_helper.execute( - self.redis_live, - lambda r: r.hset( - redis_key, - "resource_group", - sgroup_name, - ), - ) - except Exception as e: - log.exception( - "schedule({}): scheduling error!\n{}", sgroup_name, repr(e) - ) - else: - if ( - AgentResourceSyncTrigger.AFTER_SCHEDULING - in agent_resource_sync_trigger - and kernel_agent_bindings - ): - selected_agent_ids = [ - binding.agent_alloc_ctx.agent_id - for binding in kernel_agent_bindings - if binding.agent_alloc_ctx.agent_id is not None - ] - await self.registry.sync_agent_resource(db_conn, selected_agent_ids) + ), + ) + except Exception as e: + log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + await self.registry.sync_agent_resource(self.db, selected_agent_ids) await redis_helper.execute( self.redis_live, lambda r: r.hset(