diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py
index 1a23348f8c..36353a2893 100644
--- a/src/ai/backend/agent/agent.py
+++ b/src/ai/backend/agent/agent.py
@@ -1627,11 +1627,11 @@ async def pull_image(
     @abstractmethod
     async def pull_image_in_background(
         self,
-        reporter: ProgressReporter,
         image_ref: ImageRef,
         registry_conf: ImageRegistry,
         *,
-        timeout: Optional[float],
+        reporter: Optional[ProgressReporter] = None,
+        timeout: Optional[float] = None,
     ) -> None:
         """
         Pull the given image from the given registry.
diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py
index 920b76f555..7a9ee05b37 100644
--- a/src/ai/backend/agent/docker/agent.py
+++ b/src/ai/backend/agent/docker/agent.py
@@ -83,7 +83,7 @@
 from ai.backend.logging.formatter import pretty
 
 from ..agent import ACTIVE_STATUS_SET, AbstractAgent, AbstractKernelCreationContext, ComputerContext
-from ..exception import ContainerCreationError, ImagePullFailure, UnsupportedResource
+from ..exception import ContainerCreationError, UnsupportedResource
 from ..fs import create_scratch_filesystem, destroy_scratch_filesystem
 from ..kernel import AbstractKernel, KernelFeatures
 from ..proxy import DomainSocketProxy, proxy_connection
@@ -1515,11 +1515,11 @@ async def pull_image(
 
     async def pull_image_in_background(
         self,
-        reporter: ProgressReporter,
         image_ref: ImageRef,
         registry_conf: ImageRegistry,
         *,
-        timeout: Optional[float],
+        reporter: Optional[ProgressReporter] = None,
+        timeout: Optional[float] = None,
     ) -> None:
         auth_config = None
         reg_user = registry_conf.get("username")
@@ -1590,21 +1590,26 @@ def register_layer_id(resp: PullResponse) -> None:
                     | "Downloading"
                 ):
                     register_layer_id(resp)
-                    reporter.total_progress = len(layer_ids)
-                    await reporter.update()
+                    if reporter is not None:
+                        reporter.total_progress = len(layer_ids)
+                        await reporter.update()
                 case "Pull complete" | "Already exists":
                     register_layer_id(resp)
-                    reporter.total_progress = len(layer_ids)
-                    await reporter.update(1, force=True)
+                    if reporter is not None:
+                        reporter.total_progress = len(layer_ids)
+                        await reporter.update(1, force=True)
                 case status if status.startswith("Pulling from"):
                     # Pulling has started.
                     # Value of 'id' field in response dict does not represent layer id.
-                    await reporter.update(message=status)
+                    if reporter is not None:
+                        await reporter.update(message=status)
                 case status if status.startswith("Digest:") or status.startswith("Status:"):
                     # Only 'status' field exists in response dict.
-                    await reporter.update(message=status)
+                    if reporter is not None:
+                        await reporter.update(message=status)
                 case _:
-                    await reporter.update(message=resp["status"])
+                    if reporter is not None:
+                        await reporter.update(message=resp["status"])
 
         async def handle_response_for_each_layer(
             resp: PullResponse, layer_to_reporter_id_map: dict[str, UUID]
@@ -1620,5 +1625,8 @@ async def update_to_subreporter(
                 if id_ is None:
                     return None
+                if reporter is None:
+                    # Without a reporter there is no subreporter to create or update.
+                    return None
                 if id_ not in layer_to_reporter_id_map:
                     task_id = uuid4()
                     subreporter = ProgressReporter(
@@ -1645,31 +1653,38 @@ async def update_to_subreporter(
                     | "Extracting" as status
                 ):
                     await update_to_subreporter(resp, message=status)
-                    reporter.total_progress = len(layer_to_reporter_id_map.keys())
-                    await reporter.update()
+                    if reporter is not None:
+                        reporter.total_progress = len(layer_to_reporter_id_map.keys())
+                        await reporter.update()
                 case "Downloading" as status:
                     detail = resp["progressDetail"]
                     current = cast(int | None, detail.get("current"))
                     total = cast(int | None, detail.get("total"))
                     await update_to_subreporter(resp, message=status, current=current, total=total)
-                    reporter.total_progress = len(layer_to_reporter_id_map.keys())
-                    await reporter.update()
+                    if reporter is not None:
+                        reporter.total_progress = len(layer_to_reporter_id_map.keys())
+                        await reporter.update()
                 case ("Pull complete" | "Already exists") as status:
                     await update_to_subreporter(resp, message=status, force=True)
-                    reporter.total_progress = len(layer_to_reporter_id_map.keys())
-                    await reporter.update(1, force=True)
+                    if reporter is not None:
+                        reporter.total_progress = len(layer_to_reporter_id_map.keys())
+                        await reporter.update(1, force=True)
                 case status if status.startswith("Pulling from"):
                     # Pulling has started.
                     # Value of 'id' field in response dict does not represent layer id.
-                    await reporter.update(message=status)
+                    if reporter is not None:
+                        await reporter.update(message=status)
                 case status if status.startswith("Digest:") or status.startswith("Status:"):
                     # Only 'status' field exists in response dict.
-                    await reporter.update(message=status)
+                    if reporter is not None:
+                        await reporter.update(message=status)
                 case _:
-                    await reporter.update(message=resp["status"])
+                    if reporter is not None:
+                        await reporter.update(message=resp["status"])
 
         async def handle_err_response(resp: PullErrorResponse) -> None:
-            await reporter.update(message=resp.get("error"), force=True)
+            if reporter is not None:
+                await reporter.update(message=resp.get("error"), force=True)
 
         layer_to_reporter_id_map: dict[str, UUID] = {}
         layer_ids: set[str] = set()
@@ -1677,31 +1692,37 @@ async def handle_err_response(resp: PullErrorResponse) -> None:
             async for resp in docker.images.pull(
                 image_ref.canonical, auth=auth_config, stream=True
             ):
-                match resp:
-                    case dict() if resp.get("status"):
-                        _resp = PullResponse(status=resp["status"])
-                        if detail := resp.get("progressDetail"):
-                            _resp["progressDetail"] = detail
-                        if progress := resp.get("progress"):
-                            _resp["progress"] = progress
-                        if id := resp.get("id"):
-                            _resp["id"] = id
-                        if do_report_per_layer:
-                            await handle_response_for_each_layer(_resp, layer_to_reporter_id_map)
-                        else:
-                            await handle_response(_resp, layer_ids)
-                    case dict() if resp.get("error"):
-                        await handle_err_response(
-                            PullErrorResponse(
-                                error=resp["error"],
-                                errorDetail=resp["errorDetail"],
+                try:
+                    match resp:
+                        case dict() if resp.get("status"):
+                            _resp = PullResponse(status=resp["status"])
+                            if detail := resp.get("progressDetail"):
+                                _resp["progressDetail"] = detail
+                            if progress := resp.get("progress"):
+                                _resp["progress"] = progress
+                            if id_ := resp.get("id"):
+                                _resp["id"] = id_
+                            if do_report_per_layer:
+                                await handle_response_for_each_layer(
+                                    _resp, layer_to_reporter_id_map
+                                )
+                            else:
+                                await handle_response(_resp, layer_ids)
+                        case dict() if resp.get("error"):
+                            await handle_err_response(
+                                PullErrorResponse(
+                                    error=resp["error"],
+                                    errorDetail=resp["errorDetail"],
+                                )
                             )
-                        )
-                        raise ImagePullFailure(str(resp["error"]))
-                    case _:
-                        log.warning(
-                            f"Unable to deserialize pulling response. skip. (resp:{str(resp)})"
-                        )
+                            # Log the error so it is not lost when no reporter is given.
+                            log.error(
+                                f"Image pull failed (img:{image_ref.canonical},err:{resp['error']})"
+                            )
+                        case _:
+                            raise KeyError
+                except KeyError:
+                    log.warning(f"Unable to deserialize pulling response. skip. (resp:{str(resp)})")
 
     async def check_image(
         self, image_ref: ImageRef, image_id: str, auto_pull: AutoPullBehavior
diff --git a/src/ai/backend/agent/dummy/agent.py b/src/ai/backend/agent/dummy/agent.py
index 7df8d0856c..4b08894b8c 100644
--- a/src/ai/backend/agent/dummy/agent.py
+++ b/src/ai/backend/agent/dummy/agent.py
@@ -293,11 +293,11 @@ async def pull_image(
 
     async def pull_image_in_background(
         self,
-        reporter: ProgressReporter,
         image_ref: ImageRef,
         registry_conf: ImageRegistry,
         *,
-        timeout: Optional[float],
+        reporter: Optional[ProgressReporter] = None,
+        timeout: Optional[float] = None,
     ) -> None:
         return None
 
diff --git a/src/ai/backend/agent/kubernetes/agent.py b/src/ai/backend/agent/kubernetes/agent.py
index 103022cde8..87b68be6bd 100644
--- a/src/ai/backend/agent/kubernetes/agent.py
+++ b/src/ai/backend/agent/kubernetes/agent.py
@@ -1026,11 +1026,11 @@ async def pull_image(
 
     async def pull_image_in_background(
         self,
-        reporter: ProgressReporter,
         image_ref: ImageRef,
         registry_conf: ImageRegistry,
         *,
-        timeout: Optional[float],
+        reporter: Optional[ProgressReporter] = None,
+        timeout: Optional[float] = None,
     ) -> None:
         # TODO: Add support for appropriate image pulling mechanism on K8s
         pass
diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py
index 29ddf3373a..6e60231342 100644
--- a/src/ai/backend/agent/server.py
+++ b/src/ai/backend/agent/server.py
@@ -525,7 +525,7 @@ async def _pull(reporter: ProgressReporter, *, img_conf: ImageConfig) -> None:
                 )
                 try:
                     await self.agent.pull_image_in_background(
-                        reporter, img_ref, img_conf["registry"], timeout=image_pull_timeout
+                        img_ref, img_conf["registry"], reporter=reporter, timeout=image_pull_timeout
                     )
                 except asyncio.TimeoutError:
                     log.exception(f"Image pull timeout (img:{str(img_ref)},s:{image_pull_timeout})")