[do not push] dell TGI adapter #259

Closed · wants to merge 4 commits

@@ -10,6 +10,8 @@ distribution_spec:
     - remote::fireworks
     safety: meta-reference
     agents: meta-reference
-    memory: meta-reference
+    memory:
+    - remote::chromadb
+    - meta-reference
     telemetry: meta-reference
 image_type: docker
25 changes: 21 additions & 4 deletions llama_stack/providers/adapters/inference/tgi/__init__.py
@@ -6,15 +6,32 @@

 from typing import Union

-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
-from .tgi import InferenceAPIAdapter, InferenceEndpointAdapter, TGIAdapter
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)
+from .tgi import (
+    DellTGIAdapter,
+    InferenceAPIAdapter,
+    InferenceEndpointAdapter,
+    TGIAdapter,
+)


 async def get_adapter_impl(
-    config: Union[InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig],
+    config: Union[
+        InferenceAPIImplConfig,
+        InferenceEndpointImplConfig,
+        TGIImplConfig,
+        DellTGIImplConfig,
+    ],
     _deps,
 ):
-    if isinstance(config, TGIImplConfig):
+    if isinstance(config, DellTGIImplConfig):
+        impl = DellTGIAdapter()
+    elif isinstance(config, TGIImplConfig):
         impl = TGIAdapter()
     elif isinstance(config, InferenceAPIImplConfig):
         impl = InferenceAPIAdapter()
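Reviewer note: a minimal sketch (not part of this PR) of how the new isinstance() branch is expected to be exercised. The URL and model name are placeholders, and it assumes a TGI-compatible server is actually reachable at that address, since adapter initialization queries the endpoint.

import asyncio

from llama_stack.providers.adapters.inference.tgi import get_adapter_impl
from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig

async def main() -> None:
    # Hypothetical values; a DellTGIImplConfig should route to DellTGIAdapter
    # rather than the stock TGIAdapter.
    config = DellTGIImplConfig(
        url="http://127.0.0.1:5009",
        hf_model_name="meta-llama/Llama-3.1-8B-Instruct",
    )
    impl = await get_adapter_impl(config, {})  # _deps not needed for this sketch
    print(type(impl).__name__)  # expected: DellTGIAdapter

asyncio.run(main())
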
14 changes: 14 additions & 0 deletions llama_stack/providers/adapters/inference/tgi/config.py
@@ -41,3 +41,17 @@ class InferenceAPIImplConfig(BaseModel):
         default=None,
         description="Your Hugging Face user access token (will default to locally saved token if not provided)",
     )
+
+
+@json_schema_type
+class DellTGIImplConfig(BaseModel):
+    url: str = Field(
+        description="The URL for the Dell TGI endpoint (e.g. 'http://localhost:8080')",
+    )
+    hf_model_name: str = Field(
+        description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
+    )
+    api_token: Optional[str] = Field(
+        default=None,
+        description="A bearer token if your TGI endpoint is protected.",
+    )
15 changes: 14 additions & 1 deletion llama_stack/providers/adapters/inference/tgi/tgi.py
@@ -29,7 +29,12 @@
     chat_completion_request_to_model_input_info,
 )

-from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
+from .config import (
+    DellTGIImplConfig,
+    InferenceAPIImplConfig,
+    InferenceEndpointImplConfig,
+    TGIImplConfig,
+)

 logger = logging.getLogger(__name__)

@@ -173,6 +178,14 @@ async def initialize(self, config: TGIImplConfig) -> None:
         self.model_id = endpoint_info["model_id"]


+class DellTGIAdapter(_HfAdapter):
+    async def initialize(self, config: DellTGIImplConfig) -> None:
+        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        endpoint_info = await self.client.get_endpoint_info()
+        self.max_tokens = endpoint_info["max_total_tokens"]
+        self.model_id = config.hf_model_name
+
+
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
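Reviewer note: a rough smoke-test sketch (also not part of the PR) of what the new adapter's initialize() amounts to when used directly. It assumes a Dell TGI container is serving at the placeholder URL, and uses huggingface_hub's AsyncInferenceClient.text_generation for a raw call that bypasses the stack APIs.

import asyncio

from llama_stack.providers.adapters.inference.tgi.config import DellTGIImplConfig
from llama_stack.providers.adapters.inference.tgi.tgi import DellTGIAdapter

async def smoke_test() -> None:
    config = DellTGIImplConfig(
        url="http://127.0.0.1:5009",  # placeholder endpoint
        hf_model_name="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    )
    adapter = DellTGIAdapter()
    await adapter.initialize(config)

    # Unlike TGIAdapter, model_id is taken from the config rather than the
    # endpoint's /info response; max_tokens is still read from the endpoint.
    print(adapter.model_id, adapter.max_tokens)

    # Raw generation through the underlying client, outside the stack APIs.
    text = await adapter.client.text_generation("Hello", max_new_tokens=16)
    print(text)

asyncio.run(smoke_test())
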
9 changes: 9 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -87,6 +87,15 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="dell-tgi",
+                pip_packages=["huggingface_hub", "aiohttp"],
+                module="llama_stack.providers.adapters.inference.tgi",
+                config_class="llama_stack.providers.adapters.inference.tgi.DellTGIImplConfig",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
49 changes: 49 additions & 0 deletions tests/examples/dell-tgi-run.yaml
@@ -0,0 +1,49 @@
+version: '2'
+built_at: '2024-10-08T17:40:45.325529'
+image_name: local
+docker_image: null
+conda_env: local
+apis:
+- shields
+- agents
+- models
+- memory
+- memory_banks
+- inference
+- safety
+providers:
+  inference:
+  - provider_id: remote::dell-tgi
+    provider_type: remote::dell-tgi
+    config:
+      url: http://127.0.0.1:5009
+      hf_model_name: meta-llama/Llama-3.1-8B-Instruct
+  safety:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      llama_guard_shield:
+        model: Llama-Guard-3-1B
+        excluded_categories: []
+        disable_input_check: false
+        disable_output_check: false
+      prompt_guard_shield:
+        model: Prompt-Guard-86M
+  memory:
+  - provider_id: remote::chromadb
+    provider_type: remote::chromadb
+    config:
+      host: localhost
+      port: 6000
+  agents:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config:
+      persistence_store:
+        namespace: null
+        type: sqlite
+        db_path: ~/.llama/runtime/kvstore.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: meta-reference
+    config: {}
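
Reviewer note: assuming the standard llama stack CLI entry point and a Dell TGI container already listening on port 5009 (plus the Chroma instance on port 6000), this example config would be exercised with something along the lines of:

llama stack run tests/examples/dell-tgi-run.yaml --port 5000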