
Commit

don't pass db cxn to sync_agent_resource() and fix wrong types
fregataa committed Jul 19, 2024
1 parent 57cd115 commit 5c5886a
Showing 5 changed files with 95 additions and 72 deletions.
src/ai/backend/agent/server.py (26 changes: 13 additions & 13 deletions)
@@ -83,7 +83,7 @@
 )
 from .exception import ResourceError
 from .monitor import AgentErrorPluginContext, AgentStatsPluginContext
-from .types import AgentBackend, LifecycleEvent, VolumeInfo
+from .types import AgentBackend, KernelLifecycleStatus, LifecycleEvent, VolumeInfo
 from .utils import get_arch_name, get_subnet_ip

 if TYPE_CHECKING:
@@ -485,10 +485,10 @@ async def sync_and_get_kernels(
     @collect_error
     async def sync_and_get_kernels(
         self,
-        preparing_kernels: Collection[str],
-        pulling_kernels: Collection[str],
-        running_kernels: Collection[str],
-        terminating_kernels: Collection[str],
+        preparing_kernels: Collection[UUID],
+        pulling_kernels: Collection[UUID],
+        running_kernels: Collection[UUID],
+        terminating_kernels: Collection[UUID],
     ) -> dict[str, Any]:
         """
         Sync kernel_registry and containers to truth data
@@ -501,14 +501,14 @@ async def sync_and_get_kernels(
         async with self.agent.registry_lock:
             actual_existing_kernels = [
                 kid
-                for kid in self.agent.kernel_registry
-                if kid not in self.agent.terminating_kernels
+                for kid, obj in self.agent.kernel_registry.items()
+                if obj.state == KernelLifecycleStatus.RUNNING
             ]

             for raw_kernel_id in running_kernels:
-                kernel_id = KernelId(UUID(raw_kernel_id))
+                kernel_id = KernelId(raw_kernel_id)
                 if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
-                    if kernel_id in self.agent.terminating_kernels:
+                    if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                         actual_terminating_kernels.append((
                             kernel_id,
                             str(
@@ -523,9 +523,9 @@ async def sync_and_get_kernels(
                         ))

             for raw_kernel_id in terminating_kernels:
-                kernel_id = KernelId(UUID(raw_kernel_id))
+                kernel_id = KernelId(raw_kernel_id)
                 if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
-                    if kernel_id not in self.agent.terminating_kernels:
+                    if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                         await self.agent.inject_container_lifecycle_event(
                             kernel_id,
                             kernel_obj.session_id,
@@ -542,7 +542,7 @@ async def sync_and_get_kernels(

             for kernel_id, kernel_obj in self.agent.kernel_registry.items():
                 if kernel_id in terminating_kernels:
-                    if kernel_id not in self.agent.terminating_kernels:
+                    if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                         await self.agent.inject_container_lifecycle_event(
                             kernel_id,
                             kernel_obj.session_id,
@@ -563,7 +563,7 @@ async def sync_and_get_kernels(
                 else:
                     # This kernel is not alive according to the truth data.
                     # The kernel should be destroyed or cleaned
-                    if kernel_id in self.agent.terminating_kernels:
+                    if kernel_obj.state == KernelLifecycleStatus.TERMINATING:
                         await self.agent.inject_container_lifecycle_event(
                             kernel_id,
                             kernel_obj.session_id,
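
The hunks above replace membership checks against self.agent.terminating_kernels with checks on each registry entry's state field. A minimal sketch of that state-based filtering, using stand-in types rather than the actual agent classes (the real KernelLifecycleStatus enum lives in ai.backend.agent.types and may have different members):

    import enum
    import uuid
    from dataclasses import dataclass

    class KernelLifecycleStatus(enum.Enum):  # stand-in for the imported enum
        PREPARING = "preparing"
        RUNNING = "running"
        TERMINATING = "terminating"

    @dataclass
    class KernelObj:  # stand-in for the agent's kernel object
        session_id: uuid.UUID
        state: KernelLifecycleStatus

    def pick_running(registry: dict[uuid.UUID, KernelObj]) -> list[uuid.UUID]:
        # Same shape as the new list comprehension: filter by the per-kernel state
        # instead of consulting a separate terminating_kernels set.
        return [
            kid for kid, obj in registry.items()
            if obj.state is KernelLifecycleStatus.RUNNING
        ]

    registry = {
        uuid.uuid4(): KernelObj(uuid.uuid4(), KernelLifecycleStatus.RUNNING),
        uuid.uuid4(): KernelObj(uuid.uuid4(), KernelLifecycleStatus.TERMINATING),
    }
    print(pick_running(registry))  # only the RUNNING kernel id survives
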
src/ai/backend/common/types.py (2 changes: 1 addition & 1 deletion)
@@ -1301,7 +1301,7 @@ def as_trafaret(cls) -> t.Trafaret:
         from . import validators as tx

         return t.Dict({
-            t.Key("actual_existing_kernels"): tx.ToList(t.String),
+            t.Key("actual_existing_kernels"): tx.ToList(tx.UUID),
             t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
             t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)),
         })
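
This single-line change mirrors the agent-side signature change above: kernel IDs now cross the RPC boundary as UUID objects instead of raw strings, which is why KernelId(UUID(raw_kernel_id)) collapses to KernelId(raw_kernel_id). A rough plain-Python illustration of the before/after handling (not the actual tx.UUID trafaret validator):

    import uuid

    # Before: ids arrived as strings and had to be parsed on every call,
    # so a malformed id would only fail deep inside the sync loop.
    raw_kernel_ids: list[str] = ["c0ffee00-0000-4000-8000-000000000001"]
    parsed = [uuid.UUID(raw) for raw in raw_kernel_ids]

    # After: ids are validated once at the boundary and arrive as uuid.UUID,
    # so downstream code can use them directly as registry keys.
    kernel_ids: list[uuid.UUID] = parsed
    assert all(isinstance(kid, uuid.UUID) for kid in kernel_ids)
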
src/ai/backend/manager/api/session.py (23 changes: 16 additions & 7 deletions)
@@ -75,6 +75,7 @@
 from ai.backend.common.types import (
     AccessKey,
     AgentId,
+    AgentKernelRegistryByStatus,
     ClusterMode,
     ImageRegistry,
     KernelId,
@@ -86,6 +87,7 @@

 from ..config import DEFAULT_CHUNK_SIZE
 from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE
+from ..exceptions import MultiAgentError
 from ..models import (
     AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
     DEAD_SESSION_STATUSES,
@@ -993,13 +995,20 @@ async def sync_agent_resource(
         "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
     )

-    async with root_ctx.db.begin() as db_conn:
-        try:
-            await root_ctx.registry.sync_agent_resource(db_conn, [agent_id])
-        except BackendError:
-            log.exception("SYNC_AGENT_RESOURCE: exception")
-            raise
-    return web.Response(status=204)
+    try:
+        result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
+    except BackendError:
+        log.exception("SYNC_AGENT_RESOURCE: exception")
+        raise
+    val = result.get(agent_id)
+    match val:
+        case AgentKernelRegistryByStatus():
+            pass
+        case MultiAgentError():
+            return web.Response(status=500)
+        case _:
+            pass
+    return web.Response(status=204)


 @server_status_required(ALL_ALLOWED)
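
The handler now inspects the per-agent result with structural pattern matching: case AgentKernelRegistryByStatus() matches any instance of that class, and case MultiAgentError() turns a per-agent failure into an HTTP 500. A small standalone sketch of this dispatch style (Python 3.10+, with hypothetical stand-in classes rather than the Backend.AI types):

    class SyncResult:  # stands in for AgentKernelRegistryByStatus
        pass

    class SyncError(Exception):  # stands in for MultiAgentError
        pass

    def to_http_status(val: object) -> int:
        match val:
            case SyncResult():  # bare class pattern: "is an instance of", no attribute checks
                return 204
            case SyncError():
                return 500
            case _:  # no entry recorded for this agent; treat as success
                return 204

    print(to_http_status(SyncResult()), to_http_status(SyncError()), to_http_status(None))
    # -> 204 500 204
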
src/ai/backend/manager/registry.py (60 changes: 39 additions & 21 deletions)
@@ -47,7 +47,6 @@
 from redis.asyncio import Redis
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.ext.asyncio import AsyncSession as SASession
 from sqlalchemy.orm import load_only, noload, selectinload, with_loader_criteria
 from sqlalchemy.orm.exc import NoResultFound
 from yarl import URL
@@ -131,7 +130,7 @@
 )
 from .config import LocalConfig, SharedConfig
 from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS
-from .exceptions import MultiAgentError, convert_to_status_data
+from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data
 from .models import (
     AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
     AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES,
@@ -1796,6 +1795,9 @@ async def _update_kernel() -> None:
                 ex = e
             err_info = convert_to_status_data(ex, self.debug)

+            def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool:
+                return err_info["error"]["name"] == "InsufficientResource"
+
             # The agent has already cancelled or issued the destruction lifecycle event
             # for this batch of kernels.
             for binding in items:
@@ -1824,9 +1826,22 @@ async def _update_failure() -> None:
                         status_data=err_info,
                     )
                 )
-                await db_sess.execute(query)
+                await db_sess.execute(query)

             await execute_with_retry(_update_failure)
+            if (
+                AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger
+                and _is_insufficient_resource_err(err_info)
+            ):
+                await self.sync_agent_resource(
+                    self.db,
+                    [
+                        binding.agent_alloc_ctx.agent_id
+                        for binding in items
+                        if binding.agent_alloc_ctx.agent_id is not None
+                    ],
+                )
             raise

     async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair:
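
The helper added above assumes the status data produced by convert_to_status_data carries the failing exception's name under error/name. A guess at the minimal shape, only to make the predicate concrete (the real ErrorStatusInfo is defined in the manager's exceptions module and may include more fields):

    from typing import TypedDict

    class ErrorInfo(TypedDict):
        name: str

    class ErrorStatusInfo(TypedDict):
        error: ErrorInfo

    def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool:
        # Trigger an agent resource sync only when kernel creation failed
        # because the agent reported insufficient resources.
        return err_info["error"]["name"] == "InsufficientResource"

    sample: ErrorStatusInfo = {"error": {"name": "InsufficientResource"}}
    print(_is_insufficient_resource_err(sample))  # True
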
@@ -3211,7 +3226,7 @@ async def _sync_agent_resource_and_get_kerenels(

     async def sync_agent_resource(
         self,
-        db_connection: SAConnection,
+        db: ExtendedAsyncSAEngine,
         agent_ids: Collection[AgentId],
     ) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]:
         result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {}
@@ -3232,7 +3247,7 @@ async def sync_agent_resource(
                 ).options(load_only(KernelRow.id, KernelRow.status))
             )
         )
-        async with SASession(bind=db_connection) as db_session:
+        async with db.begin_readonly_session() as db_session:
             for _agent_row in await db_session.scalars(stmt):
                 agent_row = cast(AgentRow, _agent_row)
                 preparing_kernels: list[KernelId] = []
@@ -3258,24 +3273,27 @@ async def sync_agent_resource(
                     "running_kernels": running_kernels,
                     "terminating_kernels": terminating_kernels,
                 }
-        tasks = []
-        for agent_id in agent_ids:
-            tasks.append(
-                self._sync_agent_resource_and_get_kerenels(
-                    agent_id,
-                    agent_kernel_by_status[agent_id]["preparing_kernels"],
-                    agent_kernel_by_status[agent_id]["pulling_kernels"],
-                    agent_kernel_by_status[agent_id]["running_kernels"],
-                    agent_kernel_by_status[agent_id]["terminating_kernels"],
-                )
-            )
-        responses = await asyncio.gather(*tasks, return_exceptions=True)
-        for aid, resp in zip(agent_ids, responses):
+        aid_task_list: list[tuple[AgentId, asyncio.Task]] = []
+        async with aiotools.PersistentTaskGroup() as tg:
+            for agent_id in agent_ids:
+                task = tg.create_task(
+                    self._sync_agent_resource_and_get_kerenels(
+                        agent_id,
+                        agent_kernel_by_status[agent_id]["preparing_kernels"],
+                        agent_kernel_by_status[agent_id]["pulling_kernels"],
+                        agent_kernel_by_status[agent_id]["running_kernels"],
+                        agent_kernel_by_status[agent_id]["terminating_kernels"],
+                    )
+                )
+                aid_task_list.append((agent_id, task))
+        for aid, task in aid_task_list:
             agent_errors = []
-            if isinstance(resp, aiotools.TaskGroupError):
-                agent_errors.extend(resp.__errors__)
-            elif isinstance(result, Exception):
-                agent_errors.append(resp)
+            try:
+                resp = await task
+            except aiotools.TaskGroupError as e:
+                agent_errors.extend(e.__errors__)
+            except Exception as e:
+                agent_errors.append(e)
             if agent_errors:
                 result[aid] = MultiAgentError(
                     "agent(s) raise errors during kernel resource sync",
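
The rewritten fan-out keeps each task paired with its agent ID and awaits them one by one, so a failure on one agent is recorded as that agent's error instead of being flattened by asyncio.gather(return_exceptions=True). A simplified stdlib-only sketch of the pattern (the real code uses aiotools.PersistentTaskGroup and wraps errors in MultiAgentError):

    import asyncio

    async def sync_one_agent(agent_id: str) -> str:
        # Placeholder for the per-agent RPC call (sync_and_get_kernels).
        if agent_id == "agent-b":
            raise RuntimeError(f"{agent_id}: sync failed")
        return f"{agent_id}: ok"

    async def sync_agents(agent_ids: list[str]) -> dict[str, str | Exception]:
        # Fan out: one task per agent, remembered together with its agent id.
        tasks = [(aid, asyncio.create_task(sync_one_agent(aid))) for aid in agent_ids]
        results: dict[str, str | Exception] = {}
        # Fan in: await each task and record either its result or its exception,
        # so one failing agent does not hide the others' results.
        for aid, task in tasks:
            try:
                results[aid] = await task
            except Exception as e:
                results[aid] = e
        return results

    print(asyncio.run(sync_agents(["agent-a", "agent-b"])))
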
src/ai/backend/manager/scheduler/dispatcher.py (56 changes: 26 additions & 30 deletions)
@@ -306,37 +306,33 @@ def _pipeline(r: Redis) -> RedisPipeline:
         result = await db_sess.execute(query)
         schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]

-        async with self.db.begin() as db_conn:
-            for sgroup_name in schedulable_scaling_groups:
-                try:
-                    kernel_agent_bindings = await self._schedule_in_sgroup(
-                        sched_ctx,
-                        sgroup_name,
-                    )
-                    await redis_helper.execute(
-                        self.redis_live,
-                        lambda r: r.hset(
-                            redis_key,
-                            "resource_group",
-                            sgroup_name,
-                        ),
-                    )
-                except Exception as e:
-                    log.exception(
-                        "schedule({}): scheduling error!\n{}", sgroup_name, repr(e)
-                    )
-                else:
-                    if (
-                        AgentResourceSyncTrigger.AFTER_SCHEDULING
-                        in agent_resource_sync_trigger
-                        and kernel_agent_bindings
-                    ):
-                        selected_agent_ids = [
-                            binding.agent_alloc_ctx.agent_id
-                            for binding in kernel_agent_bindings
-                            if binding.agent_alloc_ctx.agent_id is not None
-                        ]
-                        await self.registry.sync_agent_resource(db_conn, selected_agent_ids)
+        for sgroup_name in schedulable_scaling_groups:
+            try:
+                kernel_agent_bindings = await self._schedule_in_sgroup(
+                    sched_ctx,
+                    sgroup_name,
+                )
+                await redis_helper.execute(
+                    self.redis_live,
+                    lambda r: r.hset(
+                        redis_key,
+                        "resource_group",
+                        sgroup_name,
+                    ),
+                )
+            except Exception as e:
+                log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
+            else:
+                if (
+                    AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
+                    and kernel_agent_bindings
+                ):
+                    selected_agent_ids = [
+                        binding.agent_alloc_ctx.agent_id
+                        for binding in kernel_agent_bindings
+                        if binding.agent_alloc_ctx.agent_id is not None
+                    ]
+                    await self.registry.sync_agent_resource(self.db, selected_agent_ids)
         await redis_helper.execute(
             self.redis_live,
             lambda r: r.hset(
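
Both this hunk and the registry.py change gate the extra sync behind membership tests such as AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger. Assuming the trigger type is a combinable enum.Flag loaded from configuration (its definition is not part of this diff), the check pattern is roughly:

    import enum

    class AgentResourceSyncTrigger(enum.Flag):  # hypothetical definition for illustration
        AFTER_SCHEDULING = enum.auto()
        ON_CREATION_FAILURE = enum.auto()

    # e.g. parsed from config; combining members with "|" enables several triggers at once
    agent_resource_sync_trigger = (
        AgentResourceSyncTrigger.AFTER_SCHEDULING | AgentResourceSyncTrigger.ON_CREATION_FAILURE
    )

    if AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger:
        # In the dispatcher this is where registry.sync_agent_resource(self.db, agent_ids) runs.
        print("AFTER_SCHEDULING sync enabled")
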
