Skip to content

Commit

Permalink
fix: Add missing resolver of Agent.compute_containers GraphQL field #…
Browse files Browse the repository at this point in the history
  • Loading branch information
fregataa authored Nov 19, 2024
1 parent a474b68 commit 224879b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 4 deletions.
1 change: 1 addition & 0 deletions changes/3011.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix `Agent.compute_containers` GraphQL field by adding missing resolver
24 changes: 20 additions & 4 deletions src/ai/backend/manager/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@
simple_db_mutate,
)
from .group import association_groups_users
from .kernel import AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, KernelRow, kernels
from .kernel import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
ComputeContainer,
KernelRow,
KernelStatus,
kernels,
)
from .keypair import keypairs
from .minilang.ordering import OrderSpecItem, QueryOrderParser
from .minilang.queryfilter import FieldSpecItem, QueryFilterParser, enum_field_getter
Expand Down Expand Up @@ -157,9 +163,7 @@ class Meta:
cpu_cur_pct = graphene.Float()
mem_cur_bytes = graphene.Float()

compute_containers = graphene.List(
"ai.backend.manager.models.ComputeContainer", status=graphene.String()
)
compute_containers = graphene.List(ComputeContainer, status=graphene.String())

@classmethod
def from_row(
Expand Down Expand Up @@ -195,6 +199,18 @@ def from_row(
used_tpu_slots=float(row["occupied_slots"].get("tpu.device", 0)),
)

async def resolve_compute_containers(
self, info: graphene.ResolveInfo, *, status: Optional[str] = None
) -> list[ComputeContainer]:
ctx: GraphQueryContext = info.context
_status = KernelStatus[status] if status is not None else None
loader = ctx.dataloader_manager.get_loader(
ctx,
"ComputeContainer.by_agent_id",
status=_status,
)
return await loader.load(self.id)

async def resolve_live_stat(self, info: graphene.ResolveInfo) -> Any:
ctx: GraphQueryContext = info.context
loader = ctx.dataloader_manager.get_loader(ctx, "Agent.live_stat")
Expand Down
27 changes: 27 additions & 0 deletions src/ai/backend/manager/models/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import (
AccessKey,
AgentId,
BinarySize,
ClusterMode,
KernelId,
Expand Down Expand Up @@ -70,6 +71,7 @@
StructuredJSONObjectListColumn,
URLColumn,
batch_multiresult,
batch_multiresult_in_scalar_stream,
batch_result,
mapper_registry,
)
Expand Down Expand Up @@ -1010,6 +1012,31 @@ async def batch_load_by_session(
lambda row: row.session_id,
)

@classmethod
async def batch_load_by_agent_id(
cls,
ctx: GraphQueryContext,
agent_ids: Sequence[AgentId],
*,
status: Optional[KernelStatus] = None,
) -> Sequence[Sequence[ComputeContainer]]:
query_stmt = (
sa.select(KernelRow)
.where(KernelRow.agent.in_(agent_ids))
.options(selectinload(KernelRow.image_row).options(selectinload(ImageRow.aliases)))
)
if status is not None:
query_stmt = query_stmt.where(KernelRow.status == status)
async with ctx.db.begin_readonly_session() as db_session:
return await batch_multiresult_in_scalar_stream(
ctx,
db_session,
query_stmt,
cls,
agent_ids,
lambda row: row.agent,
)

@classmethod
async def batch_load_detail(
cls,
Expand Down

0 comments on commit 224879b

Please sign in to comment.