Skip to content

Commit

Permalink
Don't pass a DB connection to sync_agent_resource() and fix wrong types
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jul 12, 2024
1 parent 0baf994 commit 17ac6f1
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 63 deletions.
12 changes: 6 additions & 6 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,10 +485,10 @@ async def sync_kernel_registry(
@collect_error
async def sync_and_get_kernels(
self,
preparing_kernels: Collection[str],
pulling_kernels: Collection[str],
running_kernels: Collection[str],
terminating_kernels: Collection[str],
preparing_kernels: Collection[UUID],
pulling_kernels: Collection[UUID],
running_kernels: Collection[UUID],
terminating_kernels: Collection[UUID],
) -> dict[str, Any]:
"""
Sync kernel_registry and containers to truth data
Expand All @@ -506,7 +506,7 @@ async def sync_and_get_kernels(
]

for raw_kernel_id in running_kernels:
kernel_id = KernelId(UUID(raw_kernel_id))
kernel_id = KernelId(raw_kernel_id)
if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
if kernel_id in self.agent.terminating_kernels:
actual_terminating_kernels.append((
Expand All @@ -523,7 +523,7 @@ async def sync_and_get_kernels(
))

for raw_kernel_id in terminating_kernels:
kernel_id = KernelId(UUID(raw_kernel_id))
kernel_id = KernelId(raw_kernel_id)
if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
if kernel_id not in self.agent.terminating_kernels:
await self.agent.inject_container_lifecycle_event(
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ def as_trafaret(cls) -> t.Trafaret:
from . import validators as tx

return t.Dict({
t.Key("actual_existing_kernels"): tx.ToList(t.String),
t.Key("actual_existing_kernels"): tx.ToList(tx.UUID),
t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
})
23 changes: 16 additions & 7 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from ai.backend.common.types import (
AccessKey,
AgentId,
AgentKernelRegistryByStatus,
ClusterMode,
ImageRegistry,
KernelId,
Expand All @@ -86,6 +87,7 @@

from ..config import DEFAULT_CHUNK_SIZE
from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE
from ..exceptions import MultiAgentError
from ..models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
DEAD_SESSION_STATUSES,
Expand Down Expand Up @@ -993,13 +995,20 @@ async def sync_agent_resource(
"SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
)

async with root_ctx.db.begin() as db_conn:
try:
await root_ctx.registry.sync_agent_resource(db_conn, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
return web.Response(status=204)
try:
result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
val = result.get(agent_id)
match val:
case AgentKernelRegistryByStatus():
pass
case MultiAgentError():
return web.Response(status=500)
case _:
pass
return web.Response(status=204)


@server_status_required(ALL_ALLOWED)
Expand Down
40 changes: 21 additions & 19 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from redis.asyncio import Redis
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, noload, selectinload
from sqlalchemy.orm.exc import NoResultFound
from yarl import URL
Expand Down Expand Up @@ -3161,7 +3160,7 @@ async def _sync_agent_resource_and_get_kerenels(

async def sync_agent_resource(
self,
db_connection: SAConnection,
db: ExtendedAsyncSAEngine,
agent_ids: Collection[AgentId],
) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]:
result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {}
Expand All @@ -3182,7 +3181,7 @@ async def sync_agent_resource(
).options(load_only(KernelRow.id, KernelRow.status))
)
)
async with SASession(bind=db_connection) as db_session:
async with db.begin_readonly_session() as db_session:
for _agent_row in await db_session.scalars(stmt):
agent_row = cast(AgentRow, _agent_row)
preparing_kernels: list[KernelId] = []
Expand All @@ -3208,24 +3207,27 @@ async def sync_agent_resource(
"running_kernels": running_kernels,
"terminating_kernels": terminating_kernels,
}
tasks = []
for agent_id in agent_ids:
tasks.append(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
aid_task_list: list[tuple[AgentId, asyncio.Task]] = []
async with aiotools.PersistentTaskGroup() as tg:
for agent_id in agent_ids:
task = tg.create_task(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
)
)
)
responses = await asyncio.gather(*tasks, return_exceptions=True)
for aid, resp in zip(agent_ids, responses):
aid_task_list.append((agent_id, task))
for aid, task in aid_task_list:
agent_errors = []
if isinstance(resp, aiotools.TaskGroupError):
agent_errors.extend(resp.__errors__)
elif isinstance(result, Exception):
agent_errors.append(resp)
try:
resp = await task
except aiotools.TaskGroupError as e:
agent_errors.extend(e.__errors__)
except Exception as e:
agent_errors.append(e)
if agent_errors:
result[aid] = MultiAgentError(
"agent(s) raise errors during kernel resource sync",
Expand Down
56 changes: 26 additions & 30 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,37 +306,33 @@ def _pipeline(r: Redis) -> RedisPipeline:
result = await db_sess.execute(query)
schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]

async with self.db.begin() as db_conn:
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
sgroup_name,
)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
redis_key,
"resource_group",
sgroup_name,
)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
redis_key,
"resource_group",
sgroup_name,
),
)
except Exception as e:
log.exception(
"schedule({}): scheduling error!\n{}", sgroup_name, repr(e)
)
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING
in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(db_conn, selected_agent_ids)
),
)
except Exception as e:
log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(self.db, selected_agent_ids)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
Expand Down

0 comments on commit 17ac6f1

Please sign in to comment.