Skip to content

Commit

Permalink
add launch worker ip for register
Browse files Browse the repository at this point in the history
  • Loading branch information
hainaweiben committed Jul 4, 2024
1 parent 9e091fc commit 51eaadd
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
5 changes: 4 additions & 1 deletion xinference/core/cache_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def _update_file_location(data: Dict, origin_version_info: Dict):
assert isinstance(origin_version_info["model_file_location"], dict)
origin_version_info["model_file_location"].update(data)

def record_model_version(self, version_info: Dict[str, List[Dict]], address: str):
def record_model_version(
self, version_info: Dict[str, List[Dict]], address: str, worker_ip: str
):
self._map_address_to_file_location(version_info, address)
for model_name, model_versions in version_info.items():
if model_name not in self._model_name_to_version_info:
Expand All @@ -68,6 +70,7 @@ def record_model_version(self, version_info: Dict[str, List[Dict]], address: str
self._update_file_location(
version["model_file_location"], origin_version
)
origin_version["worker_ip"] = worker_ip

def update_cache_status(
self,
Expand Down
29 changes: 28 additions & 1 deletion xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ async def register_model(
try:
register_fn(model_spec, persist)
await self._cache_tracker_ref.record_model_version(
generate_fn(model_spec), self.address
generate_fn(model_spec), self.address, worker_ip
)
except Exception as e:
unregister_fn(model_spec.model_name, raise_error=False)
Expand Down Expand Up @@ -818,6 +818,33 @@ async def _launch_one_model(_replica_model_uid):
)
replica_gpu_idx = assign_replica_gpu(_replica_model_uid, gpu_idx)
nonlocal model_type

from ..model.audio.custom import get_user_defined_audios
from ..model.embedding.custom import get_user_defined_embeddings
from ..model.image.custom import get_user_defined_images
from ..model.llm import get_user_defined_llm_families
from ..model.rerank.custom import get_user_defined_reranks

model_functions = [
get_user_defined_llm_families,
get_user_defined_embeddings,
get_user_defined_images,
get_user_defined_audios,
get_user_defined_reranks,
]

for model_func in model_functions:
for model_spec in model_func():
if model_spec.model_name == model_name:
version_info = await self.get_model_versions(
model_spec.model_type, model_name
)
for version in version_info:
target_ip_worker_ref = version["worker_ip"]
logger.info(
f"register model should launch by worker_ip: {target_ip_worker_ref}"
)

worker_ref = (
target_ip_worker_ref
if target_ip_worker_ref is not None
Expand Down

0 comments on commit 51eaadd

Please sign in to comment.