Skip to content

Commit

Permalink
handle many kernel status
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Jun 7, 2024
1 parent d4d397f commit 98e4e00
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
30 changes: 26 additions & 4 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,8 @@ async def sync_kernel_registry(
@collect_error
async def sync_and_get_kernels(
self,
preparing_kernels: Collection[str],
pulling_kernels: Collection[str],
running_kernels: Collection[str],
terminating_kernels: Collection[str],
) -> dict[str, Any]:
Expand Down Expand Up @@ -547,10 +549,30 @@ async def sync_and_get_kernels(
or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
suppress_events=False,
)
elif kernel_id not in running_kernels:
# The kernel status is not 'running' or 'terminating' in truth.
# It should be terminated.
if kernel_id not in self.agent.terminating_kernels:
elif kernel_id in running_kernels:
pass
elif kernel_id in preparing_kernels:
# kernel_registry may not have `preparing` state kernels.
pass
elif kernel_id in pulling_kernels:
# kernel_registry does not have `pulling` state kernels.
# Let's just skip them.
pass
else:
# This kernel is not alive according to the truth data.
# The kernel should be destroyed or cleaned up.
if kernel_id in self.agent.terminating_kernels:
await self.agent.inject_container_lifecycle_event(
kernel_id,
kernel_obj.session_id,
LifecycleEvent.CLEAN,
kernel_obj.termination_reason
or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
suppress_events=True,
)
elif kernel_id in self.agent.restarting_kernels:
pass
else:
await self.agent.inject_container_lifecycle_event(
kernel_id,
kernel_obj.session_id,
Expand Down
48 changes: 36 additions & 12 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,11 +3071,15 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
async def _sync_agent_resource_and_get_kerenels(
self,
agent_id: AgentId,
preparing_kernels: Collection[KernelId],
pulling_kernels: Collection[KernelId],
running_kernels: Collection[KernelId],
terminating_kernels: Collection[KernelId],
) -> AgentKernelRegistryByStatus:
async with self.agent_cache.rpc_context(agent_id) as rpc:
resp: dict[str, Any] = await rpc.call.sync_and_get_kernels(
preparing_kernels,
pulling_kernels,
running_kernels,
terminating_kernels,
)
Expand All @@ -3094,31 +3098,49 @@ async def sync_agent_resource(
.options(
selectinload(
AgentRow.kernels.and_(
KernelRow.status.in_([KernelStatus.RUNNING, KernelStatus.TERMINATING])
KernelRow.status.in_([
KernelStatus.PREPARING,
KernelStatus.PULLING,
KernelStatus.RUNNING,
KernelStatus.TERMINATING,
])
),
).options(load_only(KernelRow.id, KernelRow.status))
)
)
async with SASession(bind=db_connection) as db_session:
for _agent_row in await db_session.scalars(stmt):
agent_row = cast(AgentRow, _agent_row)
preparing_kernels: list[KernelId] = []
pulling_kernels: list[KernelId] = []
running_kernels: list[KernelId] = []
terminating_kernels: list[KernelId] = []
for kernel in agent_row.kernels:
kernel_status = cast(KernelStatus, kernel.status)
match kernel_status:
case KernelStatus.PREPARING:
preparing_kernels.append(KernelId(kernel.id))
case KernelStatus.PULLING:
pulling_kernels.append(KernelId(kernel.id))
case KernelStatus.RUNNING:
running_kernels.append(KernelId(kernel.id))
case KernelStatus.TERMINATING:
terminating_kernels.append(KernelId(kernel.id))
case _:
continue
agent_kernel_by_status[AgentId(agent_row.id)] = {
"running_kernels": [
KernelId(kern.id)
for kern in agent_row.kernels
if kern.status == KernelStatus.RUNNING
],
"terminating_kernels": [
KernelId(kern.id)
for kern in agent_row.kernels
if kern.status == KernelStatus.TERMINATING
],
"preparing_kernels": preparing_kernels,
"pulling_kernels": pulling_kernels,
"running_kernels": running_kernels,
"terminating_kernels": terminating_kernels,
}
tasks = []
for agent_id in agent_ids:
tasks.append(
self._sync_agent_resource_and_get_kerenels(
agent_id,
agent_kernel_by_status[agent_id]["preparing_kernels"],
agent_kernel_by_status[agent_id]["pulling_kernels"],
agent_kernel_by_status[agent_id]["running_kernels"],
agent_kernel_by_status[agent_id]["terminating_kernels"],
)
Expand All @@ -3136,7 +3158,9 @@ async def sync_agent_resource(
agent_errors,
)
else:
assert isinstance(resp, AgentKernelRegistryByStatus)
assert isinstance(
resp, AgentKernelRegistryByStatus
), f"response should be `AgentKernelRegistryByStatus`, not {type(resp)}"
result[aid] = resp
return result

Expand Down

0 comments on commit 98e4e00

Please sign in to comment.