Skip to content

Commit

Permalink
Fix TGI register_model() issue
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Nov 23, 2024
1 parent 4b94cd3 commit 707da55
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions llama_stack/providers/remote/inference/tgi/tgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from llama_stack.apis.models import * # noqa: F403

from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)

from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
Expand All @@ -37,37 +41,41 @@
log = logging.getLogger(__name__)


def build_model_aliases():
return [
build_model_alias(
model.huggingface_repo,
model.descriptor(),
)
for model in all_registered_models()
if model.huggingface_repo
]


class _HfAdapter(Inference, ModelsProtocolPrivate):
client: AsyncInferenceClient
max_tokens: int
model_id: str

def __init__(self) -> None:
self.formatter = ChatFormat(Tokenizer.get_instance())
self.register_helper = ModelRegistryHelper(build_model_aliases())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor()
for model in all_registered_models()
if model.huggingface_repo
}

async def register_model(self, model: Model) -> None:
async def shutdown(self) -> None:
pass

async def list_models(self) -> List[Model]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
Model(
identifier=identifier,
llama_model=identifier,
metadata={
"huggingface_repo": repo,
},
async def register_model(self, model: Model) -> None:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
)
]

async def shutdown(self) -> None:
pass
return model

async def unregister_model(self, model_id: str) -> None:
pass
Expand Down

0 comments on commit 707da55

Please sign in to comment.