Skip to content

Commit

Permalink
feat: schedule function returns list of kernel-agent binding
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa committed Aug 8, 2024
1 parent b590466 commit 2d124a5
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions src/ai/backend/manager/scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ async def _schedule_in_sgroup(
self,
sched_ctx: SchedulingContext,
sgroup_name: str,
) -> None:
) -> list[KernelAgentBinding]:
async def _apply_cancellation(
db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout"
):
Expand Down Expand Up @@ -432,7 +432,8 @@ async def _update():
len(cancelled_sessions),
)
zero = ResourceSlot()
num_scheduled = 0
kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = []

while len(pending_sessions) > 0:
async with self.db.begin_readonly_session() as db_sess:
candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name)
Expand All @@ -446,7 +447,7 @@ async def _update():
if picked_session_id is None:
# no session is picked.
# continue to next sgroup.
return
return kernel_agent_bindings_in_sgroup
for picked_idx, sess_ctx in enumerate(pending_sessions):
if sess_ctx.id == picked_session_id:
break
Expand Down Expand Up @@ -657,7 +658,7 @@ async def _update_session_status_data() -> None:
try:
match schedulable_sess.cluster_mode:
case ClusterMode.SINGLE_NODE:
await self._schedule_single_node_session(
kernel_agent_bindings = await self._schedule_single_node_session(
sched_ctx,
scheduler,
sgroup_name,
Expand All @@ -667,7 +668,7 @@ async def _update_session_status_data() -> None:
check_results,
)
case ClusterMode.MULTI_NODE:
await self._schedule_multi_node_session(
kernel_agent_bindings = await self._schedule_multi_node_session(
sched_ctx,
scheduler,
sgroup_name,
Expand Down Expand Up @@ -701,9 +702,11 @@ async def _update_session_status_data() -> None:
# _schedule_{single,multi}_node_session() already handle general exceptions.
# Proceed to the next pending session and come back later
continue
num_scheduled += 1
if num_scheduled > 0:
else:
kernel_agent_bindings_in_sgroup.extend(kernel_agent_bindings)
if kernel_agent_bindings_in_sgroup:
await self.event_producer.produce_event(DoPrepareEvent())
return kernel_agent_bindings_in_sgroup

async def _filter_agent_by_container_limit(
self, candidate_agents: list[AgentRow]
Expand Down Expand Up @@ -736,12 +739,13 @@ async def _schedule_single_node_session(
sess_ctx: SessionRow,
agent_selection_resource_priority: list[str],
check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> None:
) -> list[KernelAgentBinding]:
"""
Finds and assigns an agent having resources enough to host the entire session.
"""
log_fmt = _log_fmt.get("")
log_args = _log_args.get(tuple())
kernel_agent_bindings: list[KernelAgentBinding] = []

try:
requested_architectures = set(k.architecture for k in sess_ctx.kernels)
Expand Down Expand Up @@ -892,6 +896,10 @@ async def _schedule_single_node_session(
agent_id,
sess_ctx.requested_slots,
)
for kernel_row in sess_ctx.kernels:
kernel_agent_bindings.append(
KernelAgentBinding(kernel_row, agent_alloc_ctx, set())
)
except InstanceNotAvailable as sched_failure:
log.debug(log_fmt + "no-available-instances", *log_args)

Expand Down Expand Up @@ -1001,6 +1009,7 @@ async def _finalize_scheduled() -> None:
await self.registry.event_producer.produce_event(
SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id),
)
return kernel_agent_bindings

async def _schedule_multi_node_session(
self,
Expand All @@ -1011,7 +1020,7 @@ async def _schedule_multi_node_session(
sess_ctx: SessionRow,
agent_selection_resource_priority: list[str],
check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
) -> None:
) -> list[KernelAgentBinding]:
"""
Finds and assigns agents having resources enough to host each kernel in the session.
"""
Expand Down Expand Up @@ -1239,6 +1248,7 @@ async def _finalize_scheduled() -> None:
await self.registry.event_producer.produce_event(
SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id),
)
return kernel_agent_bindings

async def prepare(
self,
Expand Down

0 comments on commit 2d124a5

Please sign in to comment.