Skip to content

Commit

Permalink
Handle KeyError to keep pulling if the response data has an unexpected shape
Browse files: browse the repository at this point in the history
  • Loading branch information
fregataa committed Nov 19, 2024
1 parent 15e7348 commit 8ab08cf
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 60 deletions.
4 changes: 2 additions & 2 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,11 +1627,11 @@ async def pull_image(
@abstractmethod
async def pull_image_in_background(
self,
reporter: ProgressReporter,
image_ref: ImageRef,
registry_conf: ImageRegistry,
*,
timeout: Optional[float],
reporter: Optional[ProgressReporter] = None,
timeout: Optional[float] = None,
) -> None:
"""
Pull the given image from the given registry.
Expand Down
122 changes: 69 additions & 53 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
from ai.backend.logging.formatter import pretty

from ..agent import ACTIVE_STATUS_SET, AbstractAgent, AbstractKernelCreationContext, ComputerContext
from ..exception import ContainerCreationError, ImagePullFailure, UnsupportedResource
from ..exception import ContainerCreationError, UnsupportedResource
from ..fs import create_scratch_filesystem, destroy_scratch_filesystem
from ..kernel import AbstractKernel, KernelFeatures
from ..proxy import DomainSocketProxy, proxy_connection
Expand Down Expand Up @@ -1515,11 +1515,11 @@ async def pull_image(

async def pull_image_in_background(
self,
reporter: ProgressReporter,
image_ref: ImageRef,
registry_conf: ImageRegistry,
*,
timeout: Optional[float],
reporter: Optional[ProgressReporter] = None,
timeout: Optional[float] = None,
) -> None:
auth_config = None
reg_user = registry_conf.get("username")
Expand Down Expand Up @@ -1590,21 +1590,26 @@ def register_layer_id(resp: PullResponse) -> None:
| "Downloading"
):
register_layer_id(resp)
reporter.total_progress = len(layer_ids)
await reporter.update()
if reporter is not None:
reporter.total_progress = len(layer_ids)
await reporter.update()
case "Pull complete" | "Already exists":
register_layer_id(resp)
reporter.total_progress = len(layer_ids)
await reporter.update(1, force=True)
if reporter is not None:
reporter.total_progress = len(layer_ids)
await reporter.update(1, force=True)
case status if status.startswith("Pulling from"):
# Pulling has started.
# Value of 'id' field in response dict does not represent layer id.
await reporter.update(message=status)
if reporter is not None:
await reporter.update(message=status)
case status if status.startswith("Digest:") or status.startswith("Status:"):
# Only 'status' field exists in response dict.
await reporter.update(message=status)
if reporter is not None:
await reporter.update(message=status)
case _:
await reporter.update(message=resp["status"])
if reporter is not None:
await reporter.update(message=resp["status"])

async def handle_response_for_each_layer(
resp: PullResponse, layer_to_reporter_id_map: dict[str, UUID]
Expand All @@ -1620,16 +1625,18 @@ async def update_to_subreporter(
if id_ is None:
return None
if id_ not in layer_to_reporter_id_map:
task_id = uuid4()
subreporter = ProgressReporter(
bgtask_mgr.event_producer,
task_id,
cool_down_seconds=subreporter_cool_down_sec,
)
reporter.register_subreporter(subreporter)
if reporter is not None:
task_id = uuid4()
subreporter = ProgressReporter(
bgtask_mgr.event_producer,
task_id,
cool_down_seconds=subreporter_cool_down_sec,
)
reporter.register_subreporter(subreporter)
else:
task_id = layer_to_reporter_id_map[id_]
subreporter = reporter.subreporters[task_id]
if reporter is not None:
task_id = layer_to_reporter_id_map[id_]
subreporter = reporter.subreporters[task_id]
if current is not None:
subreporter.current_progress = current
if total is not None:
Expand All @@ -1645,63 +1652,72 @@ async def update_to_subreporter(
| "Extracting" as status
):
await update_to_subreporter(resp, message=status)
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update()
if reporter is not None:
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update()
case "Downloading" as status:
detail = resp["progressDetail"]
current = cast(int | None, detail.get("current"))
total = cast(int | None, detail.get("total"))
await update_to_subreporter(resp, message=status, current=current, total=total)
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update()
if reporter is not None:
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update()
case ("Pull complete" | "Already exists") as status:
await update_to_subreporter(resp, message=status, force=True)
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update(1, force=True)
if reporter is not None:
reporter.total_progress = len(layer_to_reporter_id_map.keys())
await reporter.update(1, force=True)
case status if status.startswith("Pulling from"):
# Pulling has started.
# Value of 'id' field in response dict does not represent layer id.
await reporter.update(message=status)
if reporter is not None:
await reporter.update(message=status)
case status if status.startswith("Digest:") or status.startswith("Status:"):
# Only 'status' field exists in response dict.
await reporter.update(message=status)
if reporter is not None:
await reporter.update(message=status)
case _:
await reporter.update(message=resp["status"])
if reporter is not None:
await reporter.update(message=resp["status"])

async def handle_err_response(resp: PullErrorResponse) -> None:
await reporter.update(message=resp.get("error"), force=True)
if reporter is not None:
await reporter.update(message=resp.get("error"), force=True)

layer_to_reporter_id_map: dict[str, UUID] = {}
layer_ids: set[str] = set()
async with closing_async(Docker()) as docker:
async for resp in docker.images.pull(
image_ref.canonical, auth=auth_config, stream=True
):
match resp:
case dict() if resp.get("status"):
_resp = PullResponse(status=resp["status"])
if detail := resp.get("progressDetail"):
_resp["progressDetail"] = detail
if progress := resp.get("progress"):
_resp["progress"] = progress
if id := resp.get("id"):
_resp["id"] = id
if do_report_per_layer:
await handle_response_for_each_layer(_resp, layer_to_reporter_id_map)
else:
await handle_response(_resp, layer_ids)
case dict() if resp.get("error"):
await handle_err_response(
PullErrorResponse(
error=resp["error"],
errorDetail=resp["errorDetail"],
try:
match resp:
case dict() if resp.get("status"):
_resp = PullResponse(status=resp["status"])
if detail := resp.get("progressDetail"):
_resp["progressDetail"] = detail
if progress := resp.get("progress"):
_resp["progress"] = progress
if id := resp.get("id"):
_resp["id"] = id
if do_report_per_layer:
await handle_response_for_each_layer(
_resp, layer_to_reporter_id_map
)
else:
await handle_response(_resp, layer_ids)
case dict() if resp.get("error"):
await handle_err_response(
PullErrorResponse(
error=resp["error"],
errorDetail=resp["errorDetail"],
)
)
)
raise ImagePullFailure(str(resp["error"]))
case _:
log.warning(
f"Unable to deserialize pulling response. skip. (resp:{str(resp)})"
)
case _:
raise KeyError
except KeyError:
log.warning(f"Unable to deserialize pulling response. skip. (resp:{str(resp)})")

async def check_image(
self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,11 @@ async def pull_image(

async def pull_image_in_background(
self,
reporter: ProgressReporter,
image_ref: ImageRef,
registry_conf: ImageRegistry,
*,
timeout: Optional[float],
reporter: Optional[ProgressReporter] = None,
timeout: Optional[float] = None,
) -> None:
return None

Expand Down
4 changes: 2 additions & 2 deletions src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,11 +1026,11 @@ async def pull_image(

async def pull_image_in_background(
self,
reporter: ProgressReporter,
image_ref: ImageRef,
registry_conf: ImageRegistry,
*,
timeout: Optional[float],
reporter: Optional[ProgressReporter] = None,
timeout: Optional[float] = None,
) -> None:
# TODO: Add support for appropriate image pulling mechanism on K8s
pass
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
)
try:
await self.agent.pull_image_in_background(
reporter, img_ref, img_conf["registry"], timeout=image_pull_timeout
img_ref, img_conf["registry"], reporter=reporter, timeout=image_pull_timeout
)
except asyncio.TimeoutError:
log.exception(f"Image pull timeout (img:{str(img_ref)},s:{image_pull_timeout})")
Expand Down

0 comments on commit 8ab08cf

Please sign in to comment.