diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 1bbf45260c..7287ce3b25 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -573,6 +573,36 @@ def serve(self, logging_conf: Optional[dict] = None): else None ), ) + self._router.add_api_route( + "/v1/workers", + self.get_workers_info, + methods=["GET"], + dependencies=( + [Security(self._auth_service, scopes=["admin"])] + if self.is_authenticated() + else None + ), + ) + self._router.add_api_route( + "/v1/supervisor", + self.get_supervisor_info, + methods=["GET"], + dependencies=( + [Security(self._auth_service, scopes=["admin"])] + if self.is_authenticated() + else None + ), + ) + self._router.add_api_route( + "/v1/clusters", + self.abort_cluster, + methods=["DELETE"], + dependencies=( + [Security(self._auth_service, scopes=["admin"])] + if self.is_authenticated() + else None + ), + ) if XINFERENCE_DISABLE_METRICS: logger.info( @@ -1730,6 +1760,43 @@ async def confirm_and_remove_model( logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def get_workers_info(self) -> JSONResponse: + try: + res = await (await self._get_supervisor_ref()).get_workers_info() + return JSONResponse(content=res) + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + async def get_supervisor_info(self) -> JSONResponse: + try: + res = await (await self._get_supervisor_ref()).get_supervisor_info() + return res + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + async def abort_cluster(self) -> JSONResponse: + import os + import signal + + try: + res = await (await self._get_supervisor_ref()).abort_cluster() + os.kill(os.getpid(), signal.SIGINT) + return JSONResponse(content={"result": res}) + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + def run( supervisor_address: str, diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 10d9ae8231..9a1bfe9df7 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -1324,3 +1324,33 @@ def abort_request(self, model_uid: str, request_id: str): response_data = response.json() return response_data + + def get_workers_info(self): + url = f"{self.base_url}/v1/workers" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get workers info, detail: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data + + def get_supervisor_info(self): + url = f"{self.base_url}/v1/supervisor" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get supervisor info, detail: {_get_error_string(response)}" + ) + response_json = response.json() + return response_json + + def abort_cluster(self): + url = f"{self.base_url}/v1/clusters" + response = requests.delete(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to abort cluster, detail: {_get_error_string(response)}" + ) + response_json = response.json() + return response_json diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index b905a3ba76..88f8991e16 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -14,6 +14,8 @@ import asyncio import itertools +import os +import signal import time import typing from dataclasses import dataclass @@ -217,6 +219,17 @@ async def __post_create__(self): model_version_infos, self.address ) + # Windows does not have signal handler + if os.name != "nt": + + async def signal_handler(): + os._exit(0) + + loop = asyncio.get_running_loop() + loop.add_signal_handler( + signal.SIGTERM, lambda: asyncio.create_task(signal_handler()) + ) + @typing.no_type_check async def get_cluster_device_info(self, detailed: bool = False) -> List: import psutil @@ -1153,6 +1166,34 @@ async def confirm_and_remove_model( ) return ret + async def get_workers_info(self) -> List[Dict[str, Any]]: + ret = [] + for worker in self._worker_address_to_worker.values(): + ret.append(await worker.get_workers_info()) + return ret + + async def get_supervisor_info(self) -> Dict[str, Any]: + ret = { + "supervisor_ip": self.address, + } + return ret + + async def trigger_exit(self) -> bool: + try: + os.kill(os.getpid(), signal.SIGTERM) + except Exception as e: + logger.info(f"trigger exit error: {e}") + return False + return True + + async def abort_cluster(self) -> bool: + ret = True + for worker in self._worker_address_to_worker.values(): + ret = ret and await worker.trigger_exit() + + ret = ret and await self.trigger_exit() + return ret + @staticmethod def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) diff --git a/xinference/core/worker.py b/xinference/core/worker.py index 14e9909f8c..303550d197 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -284,6 +284,14 @@ async def signal_handler(): async def __pre_destroy__(self): self._isolation.stop() + async def trigger_exit(self) -> bool: + try: + os.kill(os.getpid(), signal.SIGINT) + except Exception as e: + logger.info(f"trigger exit error: {e}") + return False + return True + @staticmethod def get_devices_count(): from ..device_utils import gpu_count @@ -863,6 +871,13 @@ async def confirm_and_remove_model(self, model_version: str) -> bool: ) return True + async def get_workers_info(self) -> Dict[str, Any]: + ret = { + "work-ip": self.address, + "models": await self.list_models(), + } + return ret + @staticmethod def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index 56b6a61182..2fb84d95c9 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -1578,5 +1578,51 @@ def cal_model_mem( print(" total: %d MB (%d GB)" % (mem_info.total, total_mem_g)) +@cli.command( + "stop-cluster", + help="Stop a cluster using the Xinference framework with the given parameters.", +) +@click.option( + "--endpoint", + "-e", + type=str, + required=True, + help="Xinference endpoint.", +) +@click.option( + "--api-key", + "-ak", + default=None, + type=str, + help="API key for accessing the Xinference API with authorization.", +) +@click.option("--check", is_flag=True, help="Confirm the deletion of the cache.") +def stop_cluster(endpoint: str, api_key: Optional[str], check: bool): + endpoint = get_endpoint(endpoint) + client = RESTfulClient(base_url=endpoint, api_key=api_key) + if api_key is None: + client._set_token(get_stored_token(endpoint, client)) + + if not check: + click.echo( + f"This command will stop Xinference cluster in {endpoint}.", err=True + ) + supervisor_info = client.get_supervisor_info() + click.echo("Supervisor information: ") + click.echo(supervisor_info) + + workers_info = client.get_workers_info() + click.echo("Workers information:") + click.echo(workers_info) + + click.confirm("Continue?", abort=True) + try: + result = client.abort_cluster() + result = result.get("result") + click.echo(f"Cluster stopped: {result}") + except Exception as e: + click.echo(e) + + if __name__ == "__main__": cli()