From 8b8ef48d03fbdf2d9d5b47ce442a6e8c8d8542e2 Mon Sep 17 00:00:00 2001
From: Sanghun Lee
Date: Sun, 26 May 2024 12:10:12 +0900
Subject: [PATCH] impl basic APIs and client side logic

---
 src/ai/backend/agent/server.py                 |   5 +-
 src/ai/backend/common/validators.py            |  17 +++
 src/ai/backend/manager/config.py               |   4 +
 src/ai/backend/manager/registry.py             | 121 +++++++++++++++++-
 .../backend/manager/scheduler/dispatcher.py    |  75 +++++++++---
 src/ai/backend/manager/types.py                |   9 ++
 6 files changed, 208 insertions(+), 23 deletions(-)

diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py
index 81994c1b0be..37c96030d5c 100644
--- a/src/ai/backend/agent/server.py
+++ b/src/ai/backend/agent/server.py
@@ -487,7 +487,7 @@ async def sync_and_get_kernels(
         terminating_kernels_from_truth: Collection[str],
     ) -> dict[str, Any]:
         """
-        Sync kernel_registry and containers to truth
+        Sync kernel_registry and containers to truth data and return kernel infos whose status is irreversible.
         """
@@ -553,8 +553,7 @@ async def sync_and_get_kernels(
                 kernel_id,
                 kernel_obj.session_id,
                 LifecycleEvent.DESTROY,
-                kernel_obj.termination_reason
-                or KernelLifecycleEventReason.ALREADY_TERMINATED,
+                kernel_obj.termination_reason or KernelLifecycleEventReason.UNKNOWN,
                 suppress_events=True,
             )
diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py
index 6f50415b3c0..4a76c45d67b 100644
--- a/src/ai/backend/common/validators.py
+++ b/src/ai/backend/common/validators.py
@@ -218,6 +218,23 @@ def check_and_return(self, value: Any) -> T_enum:
         self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)
 
 
+class EnumList(t.Trafaret, Generic[T_enum]):
+    def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None:
+        self.enum_cls = enum_cls
+        self.use_name = use_name
+
+    def check_and_return(self, value: Any) -> list[T_enum]:
+        try:
+            if self.use_name:
+                return [self.enum_cls[val] for val in value]
+            else:
+                return [self.enum_cls(val) for val in value]
+        except TypeError:
+            self._failure("cannot parse value into list", value=value)
+        except (KeyError, ValueError):
+            self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)
+
+
 class JSONString(t.Trafaret):
     def check_and_return(self, value: Any) -> dict:
         try:
diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py
index 63a39c6cf8c..3a3c0c837fc 100644
--- a/src/ai/backend/manager/config.py
+++ b/src/ai/backend/manager/config.py
@@ -225,6 +225,7 @@
 from .api.exceptions import ObjectNotFound, ServerMisconfiguredError
 from .models.session import SessionStatus
 from .pglock import PgAdvisoryLock
+from .types import DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS, AgentResourceSyncTrigger
 
 log = BraceStyleAdapter(logging.getLogger(__spec__.name))  # type: ignore[name-defined]
@@ -296,6 +297,9 @@
         "agent-selection-resource-priority",
         default=["cuda", "rocm", "tpu", "cpu", "mem"],
     ): t.List(t.String),
+    t.Key(
+        "agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS
+    ): tx.EnumList(AgentResourceSyncTrigger),
     t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String,
     t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt,  # default: 16 MiB
     tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[
diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py
index 942f6d6d8e6..6a273232905 100644
--- a/src/ai/backend/manager/registry.py
+++ b/src/ai/backend/manager/registry.py
@@ -174,16 +174,18 @@
     reenter_txn_session,
     sql_json_merge,
 )
-from .types import UserScope
+from .types import AgentResourceSyncTrigger, UserScope
 
 if TYPE_CHECKING:
     from sqlalchemy.engine.row import Row
     from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
+    from sqlalchemy.ext.asyncio import AsyncSession as SASession
 
     from ai.backend.common.auth import PublicKey, SecretKey
     from ai.backend.common.events import EventDispatcher, EventProducer
 
     from .agent_cache import AgentRPCCache
+    from .exceptions import ErrorDetail
     from .models.storage import StorageSessionManager
     from .scheduler.types import AgentAllocationContext, KernelAgentBinding, SchedulingContext
@@ -1637,6 +1639,10 @@ async def _create_kernels_in_one_agent(
         is_local = image_info["is_local"]
         resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"]
         auto_pull = image_info["auto_pull"]
+        agent_resource_sync_trigger = cast(
+            list[AgentResourceSyncTrigger],
+            self.local_config["manager"]["agent-resource-sync-trigger"],
+        )
 
         assert agent_alloc_ctx.agent_id is not None
         assert scheduled_session.id is not None
@@ -1655,6 +1661,27 @@ async def _update_kernel() -> None:
 
         await execute_with_retry(_update_kernel)
 
+        if AgentResourceSyncTrigger.BEFORE_KERNEL_CREATION in agent_resource_sync_trigger:
+            async with self.db.begin() as db_conn:
+                results = await self.sync_agent_resource(
+                    [agent_alloc_ctx.agent_id], db_connection=db_conn
+                )
+                for agent_id, result in results.items():
+                    match result:
+                        case AgentKernelRegistryByStatus(
+                            all_running_kernels,
+                            actual_terminating_kernels,
+                            actual_terminated_kernels,
+                        ):
+                            pass
+                        case MultiAgentError():
+                            pass
+                        case _:
+                            pass
+                    pass
+                async with SASession(bind=db_conn) as db_session:
+                    pass
+
         async with self.agent_cache.rpc_context(
             agent_alloc_ctx.agent_id,
             order_key=str(scheduled_session.id),
@@ -1730,9 +1757,41 @@ async def _update_kernel() -> None:
         except (asyncio.TimeoutError, asyncio.CancelledError):
             log.warning("_create_kernels_in_one_agent(s:{}) cancelled", scheduled_session.id)
         except Exception as e:
+            ex = e
+            err_info = convert_to_status_data(ex, self.debug)
+
+            def _has_insufficient_resource_err(_err_info: ErrorDetail) -> bool:
+                if _err_info["name"] == "InsufficientResource":
+                    return True
+                if (sub_errors := _err_info.get("collection")) is not None:
+                    for suberr in sub_errors:
+                        if _has_insufficient_resource_err(suberr):
+                            return True
+                return False
+
+            if AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger:
+                if _has_insufficient_resource_err(err_info["error"]):
+                    async with self.db.begin() as db_conn:
+                        results = await self.sync_agent_resource(
+                            [agent_alloc_ctx.agent_id], db_connection=db_conn
+                        )
+                        for agent_id, result in results.items():
+                            match result:
+                                case AgentKernelRegistryByStatus(
+                                    all_running_kernels,
+                                    actual_terminating_kernels,
+                                    actual_terminated_kernels,
+                                ):
+                                    pass
+                                case MultiAgentError():
+                                    pass
+                                case _:
+                                    pass
+                            pass
+                        async with SASession(bind=db_conn) as db_session:
+                            pass
             # The agent has already cancelled or issued the destruction lifecycle event
             # for this batch of kernels.
-            ex = e
             for binding in items:
                 kernel_id = binding.kernel.id
@@ -1756,7 +1815,7 @@ async def _update_failure() -> None:
                             ),  # ["PULLING", "PREPARING"]
                         },
                     ),
-                    status_data=convert_to_status_data(ex, self.debug),
+                    status_data=err_info,
                 )
             )
             await db_sess.execute(query)
@@ -3028,7 +3087,7 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
             (str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels
         ])
 
-    async def sync_and_get_kernels(
+    async def _sync_agent_resource_and_get_kernels(
         self,
         agent_id: AgentId,
         running_kernels: Collection[KernelId],
@@ -3041,6 +3100,60 @@
         )
         return AgentKernelRegistryByStatus.from_json(resp)
 
+    async def sync_agent_resource(
+        self,
+        agent_ids: Collection[AgentId],
+        *,
+        db_connection: SAConnection,
+    ) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]:
+        result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {}
+        agent_kernel_by_status: dict[AgentId, dict[str, list[KernelId]]] = {}
+        stmt = sa.select(KernelRow).where(
+            (KernelRow.agent.in_(agent_ids))
+            & (KernelRow.status.in_([KernelStatus.RUNNING, KernelStatus.TERMINATING]))
+        )
+        for kernel_row in cast(list[KernelRow], await SASession(bind=db_connection).scalars(stmt)):
+            if kernel_row.agent not in agent_kernel_by_status:
+                agent_kernel_by_status[kernel_row.agent] = {
+                    "running_kernels": [],
+                    "terminating_kernels": [],
+                }
+            # Bucket each kernel by its actual status so that the agent receives
+            # accurate running/terminating sets instead of every kernel in both.
+            if kernel_row.status == KernelStatus.RUNNING:
+                agent_kernel_by_status[kernel_row.agent]["running_kernels"].append(kernel_row.id)
+            else:
+                agent_kernel_by_status[kernel_row.agent]["terminating_kernels"].append(kernel_row.id)
+        tasks = []
+        for agent_id in agent_ids:
+            # Agents without any live kernels are still synced, with empty kernel sets.
+            kernels_by_status = agent_kernel_by_status.get(
+                agent_id, {"running_kernels": [], "terminating_kernels": []}
+            )
+            tasks.append(
+                self._sync_agent_resource_and_get_kernels(
+                    agent_id,
+                    kernels_by_status["running_kernels"],
+                    kernels_by_status["terminating_kernels"],
+                )
+            )
+        responses = await asyncio.gather(*tasks, return_exceptions=True)
+        for aid, resp in zip(agent_ids, responses):
+            agent_errors = []
+            if isinstance(resp, aiotools.TaskGroupError):
+                agent_errors.extend(resp.__errors__)
+            elif isinstance(resp, Exception):
+                agent_errors.append(resp)
+            if agent_errors:
+                result[aid] = MultiAgentError(
+                    "agent(s) raised errors during kernel resource sync",
+                    agent_errors,
+                )
+            else:
+                assert isinstance(resp, AgentKernelRegistryByStatus)
+                result[aid] = resp
+        return result
+
     async def mark_kernel_terminated(
         self,
         kernel_id: KernelId,
diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py
index 7441801959f..2fa8eab76fb 100644
--- a/src/ai/backend/manager/scheduler/dispatcher.py
+++ b/src/ai/backend/manager/scheduler/dispatcher.py
@@ -20,6 +20,7 @@
     Sequence,
     Tuple,
     Union,
+    cast,
 )
 
 import aiotools
@@ -55,6 +56,7 @@
 from ai.backend.common.plugin.hook import PASSED, HookResult
 from ai.backend.common.types import (
     AgentId,
+    AgentKernelRegistryByStatus,
     ClusterMode,
     RedisConnectionInfo,
     ResourceSlot,
@@ -68,7 +70,7 @@
 from ..api.exceptions import GenericBadRequest, InstanceNotAvailable, SessionNotFound
 from ..defs import SERVICE_MAX_RETRIES, LockID
-from ..exceptions import convert_to_status_data
+from ..exceptions import MultiAgentError, convert_to_status_data
 from ..models import (
     AgentRow,
     AgentStatus,
@@ -88,6 +90,7 @@
 )
 from ..models.utils import ExtendedAsyncSAEngine as SAEngine
 from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge
+from ..types import AgentResourceSyncTrigger
 from .predicates import (
     check_concurrency,
     check_dependencies,
@@ -259,6 +262,10 @@ async def schedule(
         log.debug("schedule(): triggered")
         manager_id = self.local_config["manager"]["id"]
         redis_key = f"manager.{manager_id}.schedule"
+        agent_resource_sync_trigger = cast(
+            list[AgentResourceSyncTrigger],
+            self.local_config["manager"]["agent-resource-sync-trigger"],
+        )
 
         def _pipeline(r: Redis) -> RedisPipeline:
             pipe = r.pipeline()
@@ -287,10 +294,6 @@ def _pipeline(r: Redis) -> RedisPipeline:
             # as its individual steps are composed of many short-lived transactions.
             async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60):
                 async with self.db.begin_readonly_session() as db_sess:
-                    # query = (
-                    #     sa.select(ScalingGroupRow)
-                    #     .join(ScalingGroupRow.agents.and_(AgentRow.status == AgentStatus.ALIVE))
-                    # )
                     query = (
                         sa.select(AgentRow.scaling_group)
                         .where(AgentRow.status == AgentStatus.ALIVE)
@@ -298,12 +301,15 @@ def _pipeline(r: Redis) -> RedisPipeline:
                     )
                     result = await db_sess.execute(query)
                    schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()]
+                produce_do_prepare = False
                 for sgroup_name in schedulable_scaling_groups:
                     try:
-                        await self._schedule_in_sgroup(
+                        kernel_agent_bindings = await self._schedule_in_sgroup(
                             sched_ctx,
                             sgroup_name,
                         )
+                        if kernel_agent_bindings:
+                            produce_do_prepare = True
                         await redis_helper.execute(
                             self.redis_live,
                             lambda r: r.hset(
@@ -314,6 +320,35 @@ def _pipeline(r: Redis) -> RedisPipeline:
                         )
                     except Exception as e:
                         log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e))
+                    else:
+                        if (
+                            AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger
+                            and kernel_agent_bindings
+                        ):
+                            selected_agent_ids = [
+                                binding.agent_alloc_ctx.agent_id
+                                for binding in kernel_agent_bindings
+                                if binding.agent_alloc_ctx.agent_id is not None
+                            ]
+                            async with self.db.begin() as db_conn:
+                                results = await self.registry.sync_agent_resource(
+                                    selected_agent_ids, db_connection=db_conn
+                                )
+                                for agent_id, result in results.items():
+                                    match result:
+                                        case AgentKernelRegistryByStatus(
+                                            all_running_kernels,
+                                            actual_terminating_kernels,
+                                            actual_terminated_kernels,
+                                        ):
+                                            pass
+                                        case MultiAgentError():
+                                            pass
+                                        case _:
+                                            pass
+                                    pass
+                                async with SASession(bind=db_conn) as db_session:
+                                    pass
                 await redis_helper.execute(
                     self.redis_live,
                     lambda r: r.hset(
@@ -322,6 +357,8 @@
                         datetime.now(tzutc()).isoformat(),
                     ),
                 )
+                if produce_do_prepare:
+                    await self.event_producer.produce_event(DoPrepareEvent())
         except DBAPIError as e:
             if getattr(e.orig, "pgcode", None) == "55P03":
                 log.info(
@@ -355,7 +392,7 @@ async def _schedule_in_sgroup(
         self,
         sched_ctx: SchedulingContext,
         sgroup_name: str,
-    ) -> None:
+    ) -> list[KernelAgentBinding]:
         async def _apply_cancellation(
             db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout"
         ):
@@ -426,7 +463,7 @@ async def _update():
                 len(cancelled_sessions),
             )
         zero = ResourceSlot()
-        num_scheduled = 0
+        kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = []
         while len(pending_sessions) > 0:
             async with self.db.begin_readonly_session() as db_sess:
                 candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name)
@@ -440,7 +477,7 @@
             if picked_session_id is None:
                 # no session is picked.
                 # continue to next sgroup.
-                return
+                return []
             for picked_idx, sess_ctx in enumerate(pending_sessions):
                 if sess_ctx.id == picked_session_id:
                     break
@@ -651,7 +688,7 @@ async def _update_session_status_data() -> None:
             try:
                 match schedulable_sess.cluster_mode:
                     case ClusterMode.SINGLE_NODE:
-                        await self._schedule_single_node_session(
+                        kernel_agent_bindings = await self._schedule_single_node_session(
                             sched_ctx,
                             scheduler,
                             sgroup_name,
@@ -661,7 +698,7 @@ async def _update_session_status_data() -> None:
                             check_results,
                         )
                     case ClusterMode.MULTI_NODE:
-                        await self._schedule_multi_node_session(
+                        kernel_agent_bindings = await self._schedule_multi_node_session(
                             sched_ctx,
                             scheduler,
                             sgroup_name,
@@ -695,9 +732,9 @@ async def _update_session_status_data() -> None:
                 # _schedule_{single,multi}_node_session() already handle general exceptions.
                 # Proceed to the next pending session and come back later
                 continue
-            num_scheduled += 1
-        if num_scheduled > 0:
-            await self.event_producer.produce_event(DoPrepareEvent())
+            else:
+                kernel_agent_bindings_in_sgroup.extend(kernel_agent_bindings)
+        return kernel_agent_bindings_in_sgroup
 
     async def _filter_agent_by_container_limit(
         self, candidate_agents: list[AgentRow]
@@ -730,7 +767,7 @@ async def _schedule_single_node_session(
         sess_ctx: SessionRow,
         agent_selection_resource_priority: list[str],
         check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
-    ) -> None:
+    ) -> list[KernelAgentBinding]:
         """
         Finds and assigns an agent having resources enough to host the entire session.
         """
@@ -994,6 +1031,11 @@ async def _finalize_scheduled() -> None:
             SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id),
         )
 
+        kernel_agent_bindings: list[KernelAgentBinding] = []
+        for kernel_row in sess_ctx.kernels:
+            kernel_agent_bindings.append(KernelAgentBinding(kernel_row, agent_alloc_ctx, set()))
+        return kernel_agent_bindings
+
     async def _schedule_multi_node_session(
         self,
         sched_ctx: SchedulingContext,
@@ -1003,7 +1045,7 @@
         sess_ctx: SessionRow,
         agent_selection_resource_priority: list[str],
         check_results: List[Tuple[str, Union[Exception, PredicateResult]]],
-    ) -> None:
+    ) -> list[KernelAgentBinding]:
         """
         Finds and assigns agents having resources enough to host each kernel in the session.
         """
@@ -1231,6 +1273,7 @@ async def _finalize_scheduled() -> None:
         await self.registry.event_producer.produce_event(
             SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id),
         )
+        return kernel_agent_bindings
 
     async def prepare(
         self,
diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py
index 7d413594deb..6a52a6cef44 100644
--- a/src/ai/backend/manager/types.py
+++ b/src/ai/backend/manager/types.py
@@ -41,3 +41,12 @@ class UserScope:
 class DistributedLockFactory(Protocol):
     def __call__(self, lock_id: LockID, lifetime_hint: float) -> AbstractDistributedLock:
         ...
+
+
+class AgentResourceSyncTrigger(enum.StrEnum):
+    AFTER_SCHEDULING = "after-scheduling"
+    BEFORE_KERNEL_CREATION = "before-kernel-creation"
+    ON_CREATION_FAILURE = "on-creation-failure"
+
+
+DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = []
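
Reviewer note: below is a minimal, self-contained sketch (not part of the patch) showing how the new EnumList trafaret turns the "agent-resource-sync-trigger" config values into AgentResourceSyncTrigger members. The class and enum bodies mirror the hunks above; the standalone imports, the T_enum alias, and the sample manager.toml line in the comment are illustrative assumptions for running this outside the source tree.

# Standalone demo of the EnumList trafaret (validators.py) combined with
# the AgentResourceSyncTrigger enum (types.py).
# Assumes Python 3.11+ (enum.StrEnum) and the "trafaret" package.
import enum
from typing import Any, Generic, Type, TypeVar

import trafaret as t

T_enum = TypeVar("T_enum", bound=enum.Enum)


class AgentResourceSyncTrigger(enum.StrEnum):
    AFTER_SCHEDULING = "after-scheduling"
    BEFORE_KERNEL_CREATION = "before-kernel-creation"
    ON_CREATION_FAILURE = "on-creation-failure"


class EnumList(t.Trafaret, Generic[T_enum]):
    def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None:
        self.enum_cls = enum_cls
        self.use_name = use_name

    def check_and_return(self, value: Any) -> list[T_enum]:
        try:
            if self.use_name:
                # Parse by member *name*, e.g. "AFTER_SCHEDULING".
                return [self.enum_cls[val] for val in value]
            else:
                # Parse by member *value*, e.g. "after-scheduling".
                return [self.enum_cls(val) for val in value]
        except TypeError:
            self._failure("cannot parse value into list", value=value)
        except (KeyError, ValueError):
            self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)


checker = EnumList(AgentResourceSyncTrigger)
# Hypothetical manager.toml line:
#   agent-resource-sync-trigger = ["after-scheduling", "on-creation-failure"]
assert checker.check(["after-scheduling", "on-creation-failure"]) == [
    AgentResourceSyncTrigger.AFTER_SCHEDULING,
    AgentResourceSyncTrigger.ON_CREATION_FAILURE,
]
# Non-iterable input and unknown members both raise t.DataError.
for bad in (None, ["no-such-trigger"]):
    try:
        checker.check(bad)
    except t.DataError:
        pass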
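
Reviewer note: sync_agent_resource() fans out one RPC per agent and degrades a failing agent to a per-agent error entry instead of aborting the whole sync. A minimal sketch of that gather-and-wrap pattern follows; query_agent and plain RuntimeError are hypothetical stand-ins for the real _sync_agent_resource_and_get_kernels() RPC and the aiotools.TaskGroupError/MultiAgentError wrapping, not the manager's actual API.

# Illustrative stand-in for the per-agent fan-out in sync_agent_resource().
import asyncio


async def query_agent(agent_id: str) -> dict[str, list[str]]:
    # Pretend agent-b is unreachable over RPC.
    if agent_id == "agent-b":
        raise RuntimeError(f"{agent_id}: unreachable")
    return {"running_kernels": [], "terminating_kernels": []}


async def sync_all(agent_ids: list[str]) -> dict[str, object]:
    # return_exceptions=True keeps one unreachable agent from cancelling
    # the remaining sync tasks; failures come back as exception objects.
    responses = await asyncio.gather(
        *(query_agent(aid) for aid in agent_ids),
        return_exceptions=True,
    )
    results: dict[str, object] = {}
    for aid, resp in zip(agent_ids, responses):
        # Successful agents map to their kernel snapshot; failed agents map
        # to the captured exception, mirroring the patch's
        # AgentKernelRegistryByStatus | MultiAgentError result union.
        results[aid] = resp
    return results


print(asyncio.run(sync_all(["agent-a", "agent-b"])))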