diff --git a/llama_stack/distribution/templates/build_configs/local-cpu-docker-build.yaml b/llama_stack/distribution/templates/build_configs/local-cpu-docker-build.yaml
index 9db019454..c0fabbf4d 100644
--- a/llama_stack/distribution/templates/build_configs/local-cpu-docker-build.yaml
+++ b/llama_stack/distribution/templates/build_configs/local-cpu-docker-build.yaml
@@ -10,6 +10,8 @@ distribution_spec:
     - remote::fireworks
     safety: meta-reference
     agents: meta-reference
-    memory: meta-reference
+    memory:
+    - remote::chromadb
+    - meta-reference
     telemetry: meta-reference
 image_type: docker
diff --git a/llama_stack/providers/adapters/inference/tgi/__init__.py b/llama_stack/providers/adapters/inference/tgi/__init__.py
index 451650323..e3b24de2f 100644
--- a/llama_stack/providers/adapters/inference/tgi/__init__.py
+++ b/llama_stack/providers/adapters/inference/tgi/__init__.py
@@ -6,15 +6,32 @@
 
 from typing import Union
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
-from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
+from .tgi import (
+    DellTGIAdapter,
+    InferenceAPIAdapter,
+    InferenceEndpointAdapter,
+    TGIAdapter,
+)
 
 
 async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    config: Union[
+        InferenceAPIImplConfig,
+        InferenceEndpointImplConfig,
+        TGIImplConfig,
+        DellTGIImplConfig,
+    ],
     _deps,
 ):
-    if isinstance(config, TGIImplConfig):
+    if isinstance(config, DellTGIImplConfig):
+        impl = DellTGIAdapter()
+    elif isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
     elif isinstance(config, InferenceAPIImplConfig):
         impl = InferenceAPIAdapter()
diff --git a/llama_stack/providers/adapters/inference/tgi/config.py b/llama_stack/providers/adapters/inference/tgi/config.py
index 6ce2b9dc6..801d5fc8f 100644
--- a/llama_stack/providers/adapters/inference/tgi/config.py
+++ b/llama_stack/providers/adapters/inference/tgi/config.py
@@ -41,3 +41,17 @@ class InferenceAPIImplConfig(BaseModel):
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
+
+
+@json_schema_type
+class DellTGIImplConfig(BaseModel):
+    url: str = Field(
+        description="The URL for the Dell TGI endpoint (e.g. 'http://localhost:8080')",
+    )
+    hf_model_name: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="A bearer token if your TGI endpoint is protected.",
+    )
diff --git a/llama_stack/providers/adapters/inference/tgi/tgi.py b/llama_stack/providers/adapters/inference/tgi/tgi.py
index cd0afad0c..4fe160045 100644
--- a/llama_stack/providers/adapters/inference/tgi/tgi.py
+++ b/llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -29,7 +29,12 @@
     chat_completion_request_to_model_input_info,
 )
 
-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -173,6 +178,14 @@ async def initialize(self, config: TGIImplConfig) -> None:
         self.model_id = endpoint_info["model_id"]
 
 
+class DellTGIAdapter(_HfAdapter):
+    async def initialize(self, config: DellTGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = config.hf_model_name
+
+
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 686fc273b..8530109c5 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -87,6 +87,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="dell-tgi",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.DellTGIImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
diff --git a/tests/examples/dell-tgi-run.yaml b/tests/examples/dell-tgi-run.yaml
new file mode 100644
index 000000000..5657dbb47
--- /dev/null
+++ b/tests/examples/dell-tgi-run.yaml
@@ -0,0 +1,49 @@
+version: '2'
+built_at: '2024-10-08T17:40:45.325529'
+image_name: local
+docker_image: null
+conda_env: local
+apis:
+- shields
+- agents
+- models
+- memory
+- memory_banks
+- inference
+- safety
+providers:
+  inference:
+  - provider_id: remote::dell-tgi
+    provider_type: remote::dell-tgi
+    config:
+      url: http://127.0.0.1:5009
+      hf_model_name: meta-llama/Llama-3.1-8B-Instruct
+  safety:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      llama_guard_shield:
+        model: Llama-Guard-3-1B
+        excluded_categories: []
+        disable_input_check: false
+        disable_output_check: false
+      prompt_guard_shield:
+        model: Prompt-Guard-86M
+  memory:
+  - provider_id: remote::chromadb
+    provider_type: remote::chromadb
+    config:
+      host: localhost
+      port: 6000
+  agents:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      persistence_store:
+        namespace: null
+        type: sqlite
+        db_path: ~/.llama/runtime/kvstore.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
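For a quick sanity check of the new adapter wiring, the following minimal sketch (not part of the patch) constructs the DellTGIImplConfig added above and initializes a DellTGIAdapter directly. The endpoint URL and model name are illustrative placeholders, and the script assumes a Dell TGI server is already serving at that address.

import asyncio

from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig
from llama_stack.providers.adapters.inference.tgi.tgi import DellTGIAdapter


async def main() -> None:
    # Placeholder values; point these at a real Dell TGI deployment.
    config = DellTGIImplConfig(
        url="http://127.0.0.1:5009",
        hf_model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    adapter = DellTGIAdapter()
    # initialize() points the AsyncInferenceClient at config.url, reads
    # max_total_tokens from the endpoint, and takes the model id from the
    # config rather than from the endpoint metadata (unlike TGIAdapter).
    await adapter.initialize(config)
    print(adapter.model_id, adapter.max_tokens)


asyncio.run(main())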