Skip to content

Commit

Permalink
dont pass db cxn to sync_agent_resource() and fix wrong types
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jul 22, 2024
1 parent 499ced2 commit 0594b41
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 79 deletions.
30 changes: 14 additions & 16 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from collections.abc import Collection
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from pathlib import Path
Expand Down Expand Up @@ -83,7 +82,7 @@
)
from .exception import ResourceError
from .monitor import AgentErrorPluginContext, AgentStatsPluginContext
from .types import AgentBackend, LifecycleEvent, VolumeInfo
from .types import AgentBackend, KernelLifecycleStatus, LifecycleEvent, VolumeInfo
from .utils import get_arch_name, get_subnet_ip

if TYPE_CHECKING:
Expand Down Expand Up @@ -485,10 +484,10 @@ async def sync_kernel_registry(
@collect_error
async def sync_and_get_kernels(
self,
preparing_kernels: Collection[str],
pulling_kernels: Collection[str],
running_kernels: Collection[str],
terminating_kernels: Collection[str],
preparing_kernels: Iterable[UUID],
pulling_kernels: Iterable[UUID],
running_kernels: Iterable[UUID],
terminating_kernels: Iterable[UUID],
) -> dict[str, Any]:
"""
Sync kernel_registry and containers to truth data
Expand All @@ -501,14 +500,14 @@ async def sync_and_get_kernels(
async with self.agent.registry_lock:
actual_existing_kernels = [
kid
for kid in self.agent.kernel_registry
if kid not in self.agent.terminating_kernels
for kid, obj in self.agent.kernel_registry.items()
if obj.state == KernelLifecycleStatus.RUNNING
]

for raw_kernel_id in running_kernels:
kernel_id = KernelId(UUID(raw_kernel_id))
kernel_id = KernelId(raw_kernel_id)
if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
if kernel_id in self.agent.terminating_kernels:
if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
actual_terminating_kernels.append((
kernel_id,
str(
Expand All @@ -523,9 +522,9 @@ async def sync_and_get_kernels(
))

for raw_kernel_id in terminating_kernels:
kernel_id = KernelId(UUID(raw_kernel_id))
kernel_id = KernelId(raw_kernel_id)
if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
if kernel_id not in self.agent.terminating_kernels:
if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
await self.agent.inject_container_lifecycle_event(
kernel_id,
kernel_obj.session_id,
Expand All @@ -542,7 +541,7 @@ async def sync_and_get_kernels(

for kernel_id, kernel_obj in self.agent.kernel_registry.items():
if kernel_id in terminating_kernels:
if kernel_id not in self.agent.terminating_kernels:
if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
await self.agent.inject_container_lifecycle_event(
kernel_id,
kernel_obj.session_id,
Expand All @@ -554,16 +553,15 @@ async def sync_and_get_kernels(
elif kernel_id in running_kernels:
pass
elif kernel_id in preparing_kernels:
# kernel_registry may not have `preparing` state kernels.
pass
elif kernel_id in pulling_kernels:
# kernel_registry does not have `pulling` state kernels.
# Let's just skip it.
# Let's skip it.
pass
else:
# This kernel is not alive according to the truth data.
# The kernel should be destroyed or cleaned
if kernel_id in self.agent.terminating_kernels:
if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
await self.agent.inject_container_lifecycle_event(
kernel_id,
kernel_obj.session_id,
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ def as_trafaret(cls) -> t.Trafaret:
from . import validators as tx

return t.Dict({
t.Key("actual_existing_kernels"): tx.ToList(t.String),
t.Key("actual_existing_kernels"): tx.ToList(tx.UUID),
t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
})
23 changes: 16 additions & 7 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from ai.backend.common.types import (
AccessKey,
AgentId,
AgentKernelRegistryByStatus,
ClusterMode,
ImageRegistry,
KernelId,
Expand All @@ -86,6 +87,7 @@

from ..config import DEFAULT_CHUNK_SIZE
from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE
from ..exceptions import MultiAgentError
from ..models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
DEAD_SESSION_STATUSES,
Expand Down Expand Up @@ -993,13 +995,20 @@ async def sync_agent_resource(
"SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
)

async with root_ctx.db.begin() as db_conn:
try:
await root_ctx.registry.sync_agent_resource(db_conn, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
return web.Response(status=204)
try:
result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
val = result.get(agent_id)
match val:
case AgentKernelRegistryByStatus():
pass
case MultiAgentError():
return web.Response(status=500)
case _:
pass
return web.Response(status=204)


@server_status_required(ALL_ALLOWED)
Expand Down
67 changes: 42 additions & 25 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import uuid
import zlib
from collections import defaultdict
from collections.abc import Collection
from collections.abc import Collection, Iterable
from datetime import datetime
from decimal import Decimal
from io import BytesIO
Expand Down Expand Up @@ -47,7 +47,6 @@
from redis.asyncio import Redis
from sqlalchemy.exc import DBAPIError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import AsyncSession as SASession
from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria
from sqlalchemy.orm.exc import NoResultFound
from yarl import URL
Expand Down Expand Up @@ -131,7 +130,7 @@
)
from .config import LocalConfig, SharedConfig
from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS
from .exceptions import MultiAgentError, convert_to_status_data
from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data
from .models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES,
Expand Down Expand Up @@ -1796,6 +1795,9 @@ async def _update_kernel() -> None:
ex = e
err_info = convert_to_status_data(ex, self.debug)

def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool:
return err_info["error"]["name"] == "InsufficientResource"

# The agent has already cancelled or issued the destruction lifecycle event
# for this batch of kernels.
for binding in items:
Expand Down Expand Up @@ -1827,6 +1829,18 @@ async def _update_failure() -> None:
await db_sess.execute(query)

await execute_with_retry(_update_failure)
if (
AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger
and _is_insufficient_resource_err(err_info)
):
await self.sync_agent_resource(
self.db,
[
binding.agent_alloc_ctx.agent_id
for binding in items
if binding.agent_alloc_ctx.agent_id is not None
],
)
raise

async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair:
Expand Down Expand Up @@ -3195,10 +3209,10 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
async def _sync_agent_resource_and_get_kerenels(
self,
agent_id: AgentId,
preparing_kernels: Collection[KernelId],
pulling_kernels: Collection[KernelId],
running_kernels: Collection[KernelId],
terminating_kernels: Collection[KernelId],
preparing_kernels: Iterable[KernelId],
pulling_kernels: Iterable[KernelId],
running_kernels: Iterable[KernelId],
terminating_kernels: Iterable[KernelId],
) -> AgentKernelRegistryByStatus:
async with self.agent_cache.rpc_context(agent_id) as rpc:
resp: dict[str, Any] = await rpc.call.sync_and_get_kernels(
Expand All @@ -3211,7 +3225,7 @@ async def _sync_agent_resource_and_get_kerenels(

async def sync_agent_resource(
self,
db_connection: SAConnection,
db: ExtendedAsyncSAEngine,
agent_ids: Collection[AgentId],
) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]:
result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {}
Expand All @@ -3232,7 +3246,7 @@ async def sync_agent_resource(
).options(load_only(KernelRow.id, KernelRow.status))
)
)
async with SASession(bind=db_connection) as db_session:
async with db.begin_readonly_session() as db_session:
for _agent_row in await db_session.scalars(stmt):
agent_row = cast(AgentRow, _agent_row)
preparing_kernels: list[KernelId] = []
Expand All @@ -3258,24 +3272,27 @@ async def sync_agent_resource(
"running_kernels": running_kernels,
"terminating_kernels": terminating_kernels,
}
tasks = []
for agent_id in agent_ids:
tasks.append(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
aid_task_list: list[tuple[AgentId, asyncio.Task]] = []
async with aiotools.PersistentTaskGroup() as tg:
for agent_id in agent_ids:
task = tg.create_task(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
)
)
)
responses = await asyncio.gather(*tasks, return_exceptions=True)
for aid, resp in zip(agent_ids, responses):
aid_task_list.append((agent_id, task))
for aid, task in aid_task_list:
agent_errors = []
if isinstance(resp, aiotools.TaskGroupError):
agent_errors.extend(resp.__errors__)
elif isinstance(result, Exception):
agent_errors.append(resp)
try:
resp = await task
except aiotools.TaskGroupError as e:
agent_errors.extend(e.__errors__)
except Exception as e:
agent_errors.append(e)
if agent_errors:
result[aid] = MultiAgentError(
"agent(s) raise errors during kernel resource sync",
Expand Down
56 changes: 26 additions & 30 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,37 +306,33 @@ def _pipeline(r: Redis) -> RedisPipeline:
result = await db_sess.execute(query)
schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]

async with self.db.begin() as db_conn:
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
sgroup_name,
)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
redis_key,
"resource_group",
sgroup_name,
)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
redis_key,
"resource_group",
sgroup_name,
),
)
except Exception as e:
log.exception(
"schedule({}): scheduling error!\n{}", sgroup_name, repr(e)
)
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING
in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(db_conn, selected_agent_ids)
),
)
except Exception as e:
log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(self.db, selected_agent_ids)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
Expand Down

0 comments on commit 0594b41

Please sign in to comment.