Skip to content

Commit

Permalink
dont pass db cxn to sync_agent_resource() and fix wrong types
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Aug 8, 2024
1 parent 70e7ad3 commit caa2113
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 56 deletions.
23 changes: 16 additions & 7 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from ai.backend.common.types import (
AccessKey,
AgentId,
AgentKernelRegistryByStatus,
ClusterMode,
ImageRegistry,
KernelId,
Expand All @@ -86,6 +87,7 @@

from ..config import DEFAULT_CHUNK_SIZE
from ..defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE
from ..exceptions import MultiAgentError
from ..models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
DEAD_SESSION_STATUSES,
Expand Down Expand Up @@ -994,13 +996,20 @@ async def sync_agent_resource(
"SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
)

async with root_ctx.db.begin() as db_conn:
try:
await root_ctx.registry.sync_agent_resource(db_conn, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
return web.Response(status=204)
try:
result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
except BackendError:
log.exception("SYNC_AGENT_RESOURCE: exception")
raise
val = result.get(agent_id)
match val:
case AgentKernelRegistryByStatus():
pass
case MultiAgentError():
return web.Response(status=500)
case _:
pass
return web.Response(status=204)


@server_status_required(ALL_ALLOWED)
Expand Down
17 changes: 16 additions & 1 deletion src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@
)
from .config import LocalConfig, SharedConfig
from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS
from .exceptions import MultiAgentError, convert_to_status_data
from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data
from .models import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES,
Expand Down Expand Up @@ -1794,6 +1794,9 @@ async def _update_kernel() -> None:
ex = e
err_info = convert_to_status_data(ex, self.debug)

def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool:
return err_info["error"]["name"] == "InsufficientResource"

# The agent has already cancelled or issued the destruction lifecycle event
# for this batch of kernels.
for binding in items:
Expand Down Expand Up @@ -1825,6 +1828,18 @@ async def _update_failure() -> None:
await db_sess.execute(query)

await execute_with_retry(_update_failure)
if (
AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger
and _is_insufficient_resource_err(err_info)
):
await self.sync_agent_resource(
self.db,
[
binding.agent_alloc_ctx.agent_id
for binding in items
if binding.agent_alloc_ctx.agent_id is not None
],
)
raise

async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair:
Expand Down
75 changes: 27 additions & 48 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
from ai.backend.common.plugin.hook import PASSED, HookResult
from ai.backend.common.types import (
AgentId,
AgentKernelRegistryByStatus,
ClusterMode,
RedisConnectionInfo,
ResourceSlot,
Expand All @@ -76,7 +75,7 @@
SessionNotFound,
)
from ..defs import SERVICE_MAX_RETRIES, LockID
from ..exceptions import MultiAgentError, convert_to_status_data
from ..exceptions import convert_to_status_data
from ..models import (
AgentRow,
AgentStatus,
Expand Down Expand Up @@ -308,47 +307,33 @@ def _pipeline(r: Redis) -> RedisPipeline:
result = await db_sess.execute(query)
schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]

async with self.db.begin() as db_conn:
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
for sgroup_name in schedulable_scaling_groups:
try:
kernel_agent_bindings = await self._schedule_in_sgroup(
sched_ctx,
sgroup_name,
)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
redis_key,
"resource_group",
sgroup_name,
)
except Exception as e:
log.exception(
"schedule({}): scheduling error!\n{}", sgroup_name, repr(e)
)
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING
in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
async with self.db.begin() as db_conn:
results = await self.registry.sync_agent_resource(
self.db, selected_agent_ids
)
for agent_id, result in results.items():
match result:
case AgentKernelRegistryByStatus(
all_running_kernels,
actual_terminating_kernels,
actual_terminated_kernels,
):
pass
case MultiAgentError():
pass
case _:
pass
pass
async with SASession(bind=db_conn) as db_session:
pass
),
)
except Exception as e:
log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
else:
if (
AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
and kernel_agent_bindings
):
selected_agent_ids = [
binding.agent_alloc_ctx.agent_id
for binding in kernel_agent_bindings
if binding.agent_alloc_ctx.agent_id is not None
]
await self.registry.sync_agent_resource(self.db, selected_agent_ids)
await redis_helper.execute(
self.redis_live,
lambda r: r.hset(
Expand Down Expand Up @@ -461,7 +446,6 @@ async def _update():
len(cancelled_sessions),
)
zero = ResourceSlot()
num_scheduled = 0
kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = []

while len(pending_sessions) > 0:
Expand Down Expand Up @@ -1041,11 +1025,6 @@ async def _finalize_scheduled() -> None:
)
return kernel_agent_bindings

kernel_agent_bindings: list[KernelAgentBinding] = []
for kernel_row in sess_ctx.kernels:
kernel_agent_bindings.append(KernelAgentBinding(kernel_row, agent_alloc_ctx, set()))
return kernel_agent_bindings

async def _schedule_multi_node_session(
self,
sched_ctx: SchedulingContext,
Expand Down

0 comments on commit caa2113

Please sign in to comment.