impl basic APIs and client side logic
fregataa committed May 26, 2024
1 parent 1eb020a commit 8b8ef48
Showing 6 changed files with 200 additions and 23 deletions.
5 changes: 2 additions & 3 deletions src/ai/backend/agent/server.py
@@ -487,7 +487,7 @@ async def sync_and_get_kernels(
terminating_kernels_from_truth: Collection[str],
) -> dict[str, Any]:
"""
Sync kernel_registry and containers to truth
Sync kernel_registry and containers to truth data
and return kernel infos whose status is irreversible.
"""

@@ -553,8 +553,7 @@ async def sync_and_get_kernels(
kernel_id,
kernel_obj.session_id,
LifecycleEvent.DESTROY,
kernel_obj.termination_reason
or KernelLifecycleEventReason.ALREADY_TERMINATED,
kernel_obj.termination_reason or KernelLifecycleEventReason.UNKNOWN,
suppress_events=True,
)

17 changes: 17 additions & 0 deletions src/ai/backend/common/validators.py
@@ -218,6 +218,23 @@ def check_and_return(self, value: Any) -> T_enum:
self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)


class EnumList(t.Trafaret, Generic[T_enum]):
def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None:
self.enum_cls = enum_cls
self.use_name = use_name

def check_and_return(self, value: Any) -> list[T_enum]:
try:
if self.use_name:
return [self.enum_cls[val] for val in value]
else:
return [self.enum_cls(val) for val in value]
except TypeError:
self._failure("cannot parse value into list", value=value)
except (KeyError, ValueError):
self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)


class JSONString(t.Trafaret):
def check_and_return(self, value: Any) -> dict:
try:
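For reference, a minimal usage sketch of the new EnumList validator (the Color enum is made up for illustration; the tx alias follows the "from ai.backend.common import validators as tx" convention used in config.py below):

import enum

import trafaret as t

from ai.backend.common import validators as tx


class Color(enum.Enum):
    RED = "red"
    BLUE = "blue"


# Values are parsed by enum value, or by member name with use_name=True.
assert tx.EnumList(Color).check(["red", "blue"]) == [Color.RED, Color.BLUE]
assert tx.EnumList(Color, use_name=True).check(["RED"]) == [Color.RED]

try:
    tx.EnumList(Color).check(["green"])  # "green" is not a member of Color
except t.DataError:
    pass  # reported as "value is not a valid member of Color"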
4 changes: 4 additions & 0 deletions src/ai/backend/manager/config.py
@@ -225,6 +225,7 @@
from .api.exceptions import ObjectNotFound, ServerMisconfiguredError
from .models.session import SessionStatus
from .pglock import PgAdvisoryLock
from .types import DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS, AgentResourceSyncTrigger

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

@@ -296,6 +297,9 @@
"agent-selection-resource-priority",
default=["cuda", "rocm", "tpu", "cpu", "mem"],
): t.List(t.String),
t.Key(
"agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS
): tx.EnumList(AgentResourceSyncTrigger),
t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String,
t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt, # default: 16 MiB
tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[
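The AgentResourceSyncTrigger enum and DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS imported above live in the manager's .types module, which is not part of this diff. A hypothetical sketch consistent with the members referenced in registry.py below (the value strings and the default list are assumptions):

import enum


class AgentResourceSyncTrigger(enum.Enum):
    # Hypothetical members; only these two names appear in registry.py below.
    BEFORE_KERNEL_CREATION = "before-kernel-creation"
    ON_CREATION_FAILURE = "on-creation-failure"


DEFAULT_AGENT_RESOURCE_SYNC_TRIGGERS = [
    AgentResourceSyncTrigger.ON_CREATION_FAILURE,
]

With string-valued members like these, an operator could set agent-resource-sync-trigger = ["before-kernel-creation", "on-creation-failure"] in the manager configuration, and the EnumList validator above would parse it into a list of enum members.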
113 changes: 109 additions & 4 deletions src/ai/backend/manager/registry.py
@@ -174,16 +174,18 @@
reenter_txn_session,
sql_json_merge,
)
from .types import UserScope
from .types import AgentResourceSyncTrigger, UserScope

if TYPE_CHECKING:
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection
from sqlalchemy.ext.asyncio import AsyncSession as SASession

from ai.backend.common.auth import PublicKey, SecretKey
from ai.backend.common.events import EventDispatcher, EventProducer

from .agent_cache import AgentRPCCache
from .exceptions import ErrorDetail
from .models.storage import StorageSessionManager
from .scheduler.types import AgentAllocationContext, KernelAgentBinding, SchedulingContext

@@ -1637,6 +1639,10 @@ async def _create_kernels_in_one_agent(
is_local = image_info["is_local"]
resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"]
auto_pull = image_info["auto_pull"]
agent_resource_sync_trigger = cast(
list[AgentResourceSyncTrigger],
self.local_config["manager"]["agent-resource-sync-trigger"],
)
assert agent_alloc_ctx.agent_id is not None
assert scheduled_session.id is not None

@@ -1655,6 +1661,27 @@ async def _update_kernel() -> None:

await execute_with_retry(_update_kernel)

if AgentResourceSyncTrigger.BEFORE_KERNEL_CREATION in agent_resource_sync_trigger:
async with self.db.begin() as db_conn:
results = await self.sync_agent_resource(
[agent_alloc_ctx.agent_id], db_connection=db_conn
)
for agent_id, result in results.items():
match result:
case AgentKernelRegistryByStatus(
all_running_kernels,
actual_terminating_kernels,
actual_terminated_kernels,
):
pass
case MultiAgentError():
pass
case _:
pass
pass
async with SASession(bind=db_conn) as db_session:
pass

async with self.agent_cache.rpc_context(
agent_alloc_ctx.agent_id,
order_key=str(scheduled_session.id),
@@ -1730,9 +1757,41 @@ async def _update_kernel() -> None:
except (asyncio.TimeoutError, asyncio.CancelledError):
log.warning("_create_kernels_in_one_agent(s:{}) cancelled", scheduled_session.id)
except Exception as e:
ex = e
err_info = convert_to_status_data(ex, self.debug)

def _has_insufficient_resource_err(_err_info: ErrorDetail) -> bool:
if _err_info["name"] == "InsufficientResource":
return True
if (sub_errors := _err_info.get("collection")) is not None:
for suberr in sub_errors:
if _has_insufficient_resource_err(suberr):
return True
return False

if AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger:
if _has_insufficient_resource_err(err_info["error"]):
async with self.db.begin() as db_conn:
results = await self.sync_agent_resource(
[agent_alloc_ctx.agent_id], db_connection=db_conn
)
for agent_id, result in results.items():
match result:
case AgentKernelRegistryByStatus(
all_running_kernels,
actual_terminating_kernels,
actual_terminated_kernels,
):
pass
case MultiAgentError():
pass
case _:
pass
pass
async with SASession(bind=db_conn) as db_session:
pass
# The agent has already cancelled or issued the destruction lifecycle event
# for this batch of kernels.
ex = e
for binding in items:
kernel_id = binding.kernel.id

@@ -1756,7 +1815,7 @@ async def _update_failure() -> None:
), # ["PULLING", "PREPARING"]
},
),
status_data=convert_to_status_data(ex, self.debug),
status_data=err_info,
)
)
await db_sess.execute(query)
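For reference, a standalone, runnable rendition of the nested-error check above. The real helper is a closure over the ErrorDetail TypedDict produced by convert_to_status_data, so the plain-dict payload here is only illustrative:

from typing import Any


def has_insufficient_resource_err(err: dict[str, Any]) -> bool:
    # Match if this entry itself is an InsufficientResource error...
    if err["name"] == "InsufficientResource":
        return True
    # ...or if any nested entry under "collection" (e.g. sub-errors of an
    # aggregated multi-agent failure) matches, checked recursively.
    for sub in err.get("collection") or []:
        if has_insufficient_resource_err(sub):
            return True
    return False


status_data = {
    "error": {
        "name": "MultiAgentError",
        "collection": [
            {"name": "GenericError"},
            {"name": "InsufficientResource"},
        ],
    },
}
assert has_insufficient_resource_err(status_data["error"])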
@@ -3028,7 +3087,7 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None:
(str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels
])

async def sync_and_get_kernels(
async def _sync_agent_resource_and_get_kernels(
self,
agent_id: AgentId,
running_kernels: Collection[KernelId],
@@ -3041,6 +3100,52 @@
)
return AgentKernelRegistryByStatus.from_json(resp)

async def sync_agent_resource(
self,
agent_ids: Collection[AgentId],
*,
db_connection: SAConnection,
) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]:
result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {}
agent_kernel_by_status: dict[AgentId, dict[str, list[KernelId]]] = {}
stmt = sa.select(KernelRow).where(
(KernelRow.agent.in_(agent_ids))
& (KernelRow.status.in_([KernelStatus.RUNNING, KernelStatus.TERMINATING]))
)
for kernel_row in cast(list[KernelRow], await SASession(bind=db_connection).scalars(stmt)):
if kernel_row.agent not in agent_kernel_by_status:
agent_kernel_by_status[kernel_row.agent] = {
"running_kernels": [],
"terminating_kernels": [],
}
agent_kernel_by_status[kernel_row.agent]["running_kernels"].append(kernel_row.id)
agent_kernel_by_status[kernel_row.agent]["terminating_kernels"].append(kernel_row.id)
tasks = []
for agent_id in agent_ids:
tasks.append(
self._sync_agent_resource_and_get_kernels(
agent_id,
agent_kernel_by_status.get(agent_id, {}).get("running_kernels", []),
agent_kernel_by_status.get(agent_id, {}).get("terminating_kernels", []),
)
)
responses = await asyncio.gather(*tasks, return_exceptions=True)
for aid, resp in zip(agent_ids, responses):
agent_errors = []
if isinstance(resp, aiotools.TaskGroupError):
agent_errors.extend(resp.__errors__)
elif isinstance(resp, Exception):
agent_errors.append(resp)
if agent_errors:
result[aid] = MultiAgentError(
"agent(s) raise errors during kernel resource sync",
agent_errors,
)
else:
assert isinstance(resp, AgentKernelRegistryByStatus)
result[aid] = resp
return result

async def mark_kernel_terminated(
self,
kernel_id: KernelId,
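A hypothetical caller-side sketch of how the per-agent results of sync_agent_resource() appear intended to be consumed; the reconciliation step is still stubbed in this commit, so everything after the isinstance check below is only a guess at the intended shape:

import logging

log = logging.getLogger(__name__)


async def resync_agents(registry, agent_ids) -> None:
    async with registry.db.begin() as db_conn:
        results = await registry.sync_agent_resource(agent_ids, db_connection=db_conn)
    for agent_id, result in results.items():
        if isinstance(result, Exception):
            # MultiAgentError: one or more RPC calls failed for this agent.
            log.warning("resource sync failed on agent %s: %r", agent_id, result)
            continue
        # AgentKernelRegistryByStatus: reconcile kernel/session rows against the
        # agent-reported running/terminating/terminated kernel sets here.
        ...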
