diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 6eedc637f1..6c3768f389 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -361,7 +361,7 @@ async def _schedule_in_sgroup( self, sched_ctx: SchedulingContext, sgroup_name: str, - ) -> None: + ) -> list[KernelAgentBinding]: async def _apply_cancellation( db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout" ): @@ -432,7 +432,8 @@ async def _update(): len(cancelled_sessions), ) zero = ResourceSlot() - num_scheduled = 0 + kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = [] + while len(pending_sessions) > 0: async with self.db.begin_readonly_session() as db_sess: candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name) @@ -446,7 +447,7 @@ async def _update(): if picked_session_id is None: # no session is picked. # continue to next sgroup. - return + return kernel_agent_bindings_in_sgroup for picked_idx, sess_ctx in enumerate(pending_sessions): if sess_ctx.id == picked_session_id: break @@ -657,7 +658,7 @@ async def _update_session_status_data() -> None: try: match schedulable_sess.cluster_mode: case ClusterMode.SINGLE_NODE: - await self._schedule_single_node_session( + kernel_agent_bindings = await self._schedule_single_node_session( sched_ctx, scheduler, sgroup_name, @@ -667,7 +668,7 @@ async def _update_session_status_data() -> None: check_results, ) case ClusterMode.MULTI_NODE: - await self._schedule_multi_node_session( + kernel_agent_bindings = await self._schedule_multi_node_session( sched_ctx, scheduler, sgroup_name, @@ -701,9 +702,11 @@ async def _update_session_status_data() -> None: # _schedule_{single,multi}_node_session() already handle general exceptions. # Proceed to the next pending session and come back later continue - num_scheduled += 1 - if num_scheduled > 0: + else: + kernel_agent_bindings_in_sgroup.extend(kernel_agent_bindings) + if kernel_agent_bindings_in_sgroup: await self.event_producer.produce_event(DoPrepareEvent()) + return kernel_agent_bindings_in_sgroup async def _filter_agent_by_container_limit( self, candidate_agents: list[AgentRow] @@ -736,12 +739,13 @@ async def _schedule_single_node_session( sess_ctx: SessionRow, agent_selection_resource_priority: list[str], check_results: List[Tuple[str, Union[Exception, PredicateResult]]], - ) -> None: + ) -> list[KernelAgentBinding]: """ Finds and assigns an agent having resources enough to host the entire session. """ log_fmt = _log_fmt.get("") log_args = _log_args.get(tuple()) + kernel_agent_bindings: list[KernelAgentBinding] = [] try: requested_architectures = set(k.architecture for k in sess_ctx.kernels) @@ -892,6 +896,10 @@ async def _schedule_single_node_session( agent_id, sess_ctx.requested_slots, ) + for kernel_row in sess_ctx.kernels: + kernel_agent_bindings.append( + KernelAgentBinding(kernel_row, agent_alloc_ctx, set()) + ) except InstanceNotAvailable as sched_failure: log.debug(log_fmt + "no-available-instances", *log_args) @@ -1001,6 +1009,7 @@ async def _finalize_scheduled() -> None: await self.registry.event_producer.produce_event( SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id), ) + return kernel_agent_bindings async def _schedule_multi_node_session( self, @@ -1011,7 +1020,7 @@ async def _schedule_multi_node_session( sess_ctx: SessionRow, agent_selection_resource_priority: list[str], check_results: List[Tuple[str, Union[Exception, PredicateResult]]], - ) -> None: + ) -> list[KernelAgentBinding]: """ Finds and assigns agents having resources enough to host each kernel in the session. """ @@ -1239,6 +1248,7 @@ async def _finalize_scheduled() -> None: await self.registry.event_producer.produce_event( SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id), ) + return kernel_agent_bindings async def prepare( self,