diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 2d0100ea99f..badab3d5549 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -485,10 +485,10 @@ async def sync_kernel_registry( @collect_error async def sync_and_get_kernels( self, - preparing_kernels: Collection[str], - pulling_kernels: Collection[str], - running_kernels: Collection[str], - terminating_kernels: Collection[str], + preparing_kernels: Collection[UUID], + pulling_kernels: Collection[UUID], + running_kernels: Collection[UUID], + terminating_kernels: Collection[UUID], ) -> dict[str, Any]: """ Sync kernel_registry and containers to truth data @@ -506,7 +506,7 @@ async def sync_and_get_kernels( ] for raw_kernel_id in running_kernels: - kernel_id = KernelId(UUID(raw_kernel_id)) + kernel_id = KernelId(raw_kernel_id) if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: if kernel_id in self.agent.terminating_kernels: actual_terminating_kernels.append(( @@ -523,7 +523,7 @@ async def sync_and_get_kernels( )) for raw_kernel_id in terminating_kernels: - kernel_id = KernelId(UUID(raw_kernel_id)) + kernel_id = KernelId(raw_kernel_id) if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: if kernel_id not in self.agent.terminating_kernels: await self.agent.inject_container_lifecycle_event( diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 183c16ad2fa..6b74d9310eb 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1301,7 +1301,7 @@ def as_trafaret(cls) -> t.Trafaret: from . import validators as tx return t.Dict({ - t.Key("actual_existing_kernels"): tx.ToList(t.String), + t.Key("actual_existing_kernels"): tx.ToList(tx.UUID), t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)), t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)), }) diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 184d4db7fb2..6738e435e67 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -75,6 +75,7 @@ from ai.backend.common.types import ( AccessKey, AgentId, + AgentKernelRegistryByStatus, ClusterMode, ImageRegistry, KernelId, @@ -86,6 +87,7 @@ from ..config import DEFAULT_CHUNK_SIZE from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE +from ..exceptions import MultiAgentError from ..models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, DEAD_SESSION_STATUSES, @@ -993,13 +995,20 @@ async def sync_agent_resource( "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id ) - async with root_ctx.db.begin() as db_conn: - try: - await root_ctx.registry.sync_agent_resource(db_conn, [agent_id]) - except BackendError: - log.exception("SYNC_AGENT_RESOURCE: exception") - raise - return web.Response(status=204) + try: + result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id]) + except BackendError: + log.exception("SYNC_AGENT_RESOURCE: exception") + raise + val = result.get(agent_id) + match val: + case AgentKernelRegistryByStatus(): + pass + case MultiAgentError(): + return web.Response(status=500) + case _: + pass + return web.Response(status=204) @server_status_required(ALL_ALLOWED) diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index ea8d62b9505..4c7d280de92 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -47,7 +47,6 @@ from redis.asyncio import Redis from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.ext.asyncio import AsyncSession as SASession from sqlalchemy.orm import load_only, noload, selectinload from sqlalchemy.orm.exc import NoResultFound from yarl import URL @@ -3161,7 +3160,7 @@ async def _sync_agent_resource_and_get_kerenels( async def sync_agent_resource( self, - db_connection: SAConnection, + db: ExtendedAsyncSAEngine, agent_ids: Collection[AgentId], ) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]: result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {} @@ -3182,7 +3181,7 @@ async def sync_agent_resource( ).options(load_only(KernelRow.id, KernelRow.status)) ) ) - async with SASession(bind=db_connection) as db_session: + async with db.begin_readonly_session() as db_session: for _agent_row in await db_session.scalars(stmt): agent_row = cast(AgentRow, _agent_row) preparing_kernels: list[KernelId] = [] @@ -3208,24 +3207,27 @@ async def sync_agent_resource( "running_kernels": running_kernels, "terminating_kernels": terminating_kernels, } - tasks = [] - for agent_id in agent_ids: - tasks.append( - self._sync_agent_resource_and_get_kerenels( - agent_id, - agent_kernel_by_status[agent_id]["preparing_kernels"], - agent_kernel_by_status[agent_id]["pulling_kernels"], - agent_kernel_by_status[agent_id]["running_kernels"], - agent_kernel_by_status[agent_id]["terminating_kernels"], + aid_task_list: list[tuple[AgentId, asyncio.Task]] = [] + async with aiotools.PersistentTaskGroup() as tg: + for agent_id in agent_ids: + task = tg.create_task( + self._sync_agent_resource_and_get_kerenels( + agent_id, + agent_kernel_by_status[agent_id]["preparing_kernels"], + agent_kernel_by_status[agent_id]["pulling_kernels"], + agent_kernel_by_status[agent_id]["running_kernels"], + agent_kernel_by_status[agent_id]["terminating_kernels"], + ) ) - ) - responses = await asyncio.gather(*tasks, return_exceptions=True) - for aid, resp in zip(agent_ids, responses): + aid_task_list.append((agent_id, task)) + for aid, task in aid_task_list: agent_errors = [] - if isinstance(resp, aiotools.TaskGroupError): - agent_errors.extend(resp.__errors__) - elif isinstance(result, Exception): - agent_errors.append(resp) + try: + resp = await task + except aiotools.TaskGroupError as e: + agent_errors.extend(e.__errors__) + except Exception as e: + agent_errors.append(e) if agent_errors: result[aid] = MultiAgentError( "agent(s) raise errors during kernel resource sync", diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index aa3a52d77cd..4a60fa365f2 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -306,37 +306,33 @@ def _pipeline(r: Redis) -> RedisPipeline: result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - async with self.db.begin() as db_conn: - for sgroup_name in schedulable_scaling_groups: - try: - kernel_agent_bindings = await self._schedule_in_sgroup( - sched_ctx, + for sgroup_name in schedulable_scaling_groups: + try: + kernel_agent_bindings = await self._schedule_in_sgroup( + sched_ctx, + sgroup_name, + ) + await redis_helper.execute( + self.redis_live, + lambda r: r.hset( + redis_key, + "resource_group", sgroup_name, - ) - await redis_helper.execute( - self.redis_live, - lambda r: r.hset( - redis_key, - "resource_group", - sgroup_name, - ), - ) - except Exception as e: - log.exception( - "schedule({}): scheduling error!\n{}", sgroup_name, repr(e) - ) - else: - if ( - AgentResourceSyncTrigger.AFTER_SCHEDULING - in agent_resource_sync_trigger - and kernel_agent_bindings - ): - selected_agent_ids = [ - binding.agent_alloc_ctx.agent_id - for binding in kernel_agent_bindings - if binding.agent_alloc_ctx.agent_id is not None - ] - await self.registry.sync_agent_resource(db_conn, selected_agent_ids) + ), + ) + except Exception as e: + log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + await self.registry.sync_agent_resource(self.db, selected_agent_ids) await redis_helper.execute( self.redis_live, lambda r: r.hset(