From 490f23dd7f885d4a459bbc7ae566fbfdf7df94cb Mon Sep 17 00:00:00 2001
From: Facundo Santiago
Date: Sun, 1 Sep 2024 20:07:53 +0200
Subject: [PATCH] fix: GitHub Models metadata retrieval (#15747)

---
 docs/docs/examples/llm/azure_inference.ipynb  | 41 ++++++++++++++++---
 .../llama_index/llms/azure_inference/base.py  | 25 ++++++++---
 .../pyproject.toml                            |  2 +-
 .../tests/test_llms_azure_inference.py        | 28 +++++++++++++
 4 files changed, 85 insertions(+), 11 deletions(-)

diff --git a/docs/docs/examples/llm/azure_inference.ipynb b/docs/docs/examples/llm/azure_inference.ipynb
index b8884d5a726f7..f16990fd9c5c2 100644
--- a/docs/docs/examples/llm/azure_inference.ipynb
+++ b/docs/docs/examples/llm/azure_inference.ipynb
@@ -15,7 +15,7 @@
    "source": [
     "# Azure AI model inference\n",
     "\n",
-    "This notebook explains how to use `llama-index-llm-azure-inference` package with models deployed with the Azure AI model inference API in Azure AI studio or Azure Machine Learning."
+    "This notebook explains how to use the `llama-index-llms-azure-inference` package with models deployed with the Azure AI model inference API in Azure AI studio or Azure Machine Learning. The package also supports GitHub Models (Preview) endpoints."
    ]
   },
   {
@@ -68,7 +68,8 @@
     "3. Deploy one model supporting the [Azure AI model inference API](https://aka.ms/azureai/modelinference). In this example we use a `Mistral-Large` deployment. \n",
     "\n",
     " * You can follow the instructions at [Deploy models as serverless APIs](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless).\n",
-    "\n"
+    "\n",
+    "Alternatively, you can use GitHub Models endpoints with this integration, including the free tier experience. Read more about [GitHub Models](https://github.com/marketplace/models)."
    ]
   },
   {
@@ -119,7 +120,7 @@
    "id": "a593031b-c872-4360-8775-dff4844ccead",
    "metadata": {},
    "source": [
-    "## Use the model"
+    "## Connect to your deployment and endpoint"
    ]
   },
   {
@@ -186,7 +187,7 @@
   },
   {
    "cell_type": "markdown",
-   "id": "97fb9877",
+   "id": "ed641e58",
    "metadata": {},
    "source": [
     "If you are planning to use asynchronous calling, it's a best practice to use the asynchronous version for the credentials:"
@@ -195,7 +196,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "d7bc2c98",
+   "id": "8aa5e256",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -209,6 +210,36 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "e3a6ad14",
+   "metadata": {},
+   "source": [
+    "If your endpoint serves more than one model, like [GitHub Models](https://github.com/marketplace/models) or Azure AI Services, you need to specify the `model_name` parameter:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f95b7416",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = AzureAICompletionsModel(\n",
+    "    endpoint=os.environ[\"AZURE_INFERENCE_ENDPOINT\"],\n",
+    "    credential=os.environ[\"AZURE_INFERENCE_CREDENTIAL\"],\n",
+    "    model_name=\"mistral-large\",  # change it to the model you want to use\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f414500c",
+   "metadata": {},
+   "source": [
+    "## Use the model"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "579ce31c-7b51-471e-bcb5-47da90b3d555",
diff --git a/llama-index-integrations/llms/llama-index-llms-azure-inference/llama_index/llms/azure_inference/base.py b/llama-index-integrations/llms/llama-index-llms-azure-inference/llama_index/llms/azure_inference/base.py
index 9f8f4d4ed6981..f5ffe962927fd 100644
--- 
a/llama-index-integrations/llms/llama-index-llms-azure-inference/llama_index/llms/azure_inference/base.py +++ b/llama-index-integrations/llms/llama-index-llms-azure-inference/llama_index/llms/azure_inference/base.py @@ -1,6 +1,7 @@ """Azure AI model inference chat completions client.""" import json +import logging from typing import ( Any, Callable, @@ -51,12 +52,15 @@ from azure.core.credentials import TokenCredential from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import HttpResponseError from azure.ai.inference.models import ( ChatCompletionsToolCall, ChatRequestMessage, ChatResponseMessage, ) +logger = logging.getLogger(__name__) + def to_inference_message( messages: Sequence[ChatMessage], @@ -279,11 +283,22 @@ def class_name(cls) -> str: @property def metadata(self) -> LLMMetadata: if not self._model_name: - model_info = self._client.get_model_info() - if model_info: - self._model_name = model_info.get("model_name", None) - self._model_type = model_info.get("model_type", None) - self._model_provider = model_info.get("model_provider_name", None) + try: + # Get model info from the endpoint. This method may not be supported by all + # endpoints. + model_info = self._client.get_model_info() + if model_info: + self._model_name = model_info.get("model_name", None) + self._model_type = model_info.get("model_type", None) + self._model_provider = model_info.get("model_provider_name", None) + except HttpResponseError: + logger.warning( + f"Endpoint '{self._client._config.endpoint}' does not support model metadata retrieval. " + "Failed to get model info for method `metadata()`." + ) + self._model_name = "unknown" + self._model_provider = "unknown" + self._model_type = "chat-completions" return LLMMetadata( is_chat_model=self._model_type == "chat-completions", diff --git a/llama-index-integrations/llms/llama-index-llms-azure-inference/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-azure-inference/pyproject.toml index 50d84e85c7061..bfe4a869e1214 100644 --- a/llama-index-integrations/llms/llama-index-llms-azure-inference/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-azure-inference/pyproject.toml @@ -28,7 +28,7 @@ license = "MIT" name = "llama-index-llms-azure-inference" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.2.1" +version = "0.2.2" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/llms/llama-index-llms-azure-inference/tests/test_llms_azure_inference.py b/llama-index-integrations/llms/llama-index-llms-azure-inference/tests/test_llms_azure_inference.py index 3ffd23de89c77..8fa0003b5be0f 100644 --- a/llama-index-integrations/llms/llama-index-llms-azure-inference/tests/test_llms_azure_inference.py +++ b/llama-index-integrations/llms/llama-index-llms-azure-inference/tests/test_llms_azure_inference.py @@ -1,3 +1,4 @@ +import logging import os import pytest import json @@ -5,6 +6,8 @@ from llama_index.core.llms import ChatMessage, MessageRole from llama_index.core.tools import FunctionTool +logger = logging.getLogger(__name__) + @pytest.mark.skipif( not { @@ -118,3 +121,28 @@ def echo(message: str) -> str: response.message.additional_kwargs["tool_calls"][0]["function"]["name"] == "echo" ) + + +@pytest.mark.skipif( + not { + "AZURE_INFERENCE_ENDPOINT", + "AZURE_INFERENCE_CREDENTIAL", + }.issubset(set(os.environ)), + reason="Azure AI endpoint and/or credential are not set.", +) +def test_get_metadata(caplog): + """Tests if we can get model metadata back from 
the endpoint. If so, + model_name should not be 'unknown'. Some endpoints may not support this + and in those cases a warning should be logged. + """ + # In case the endpoint being tested serves more than one model + model_name = os.environ.get("AZURE_INFERENCE_MODEL", None) + + llm = AzureAICompletionsModel(model_name=model_name) + + response = llm.metadata + + assert ( + response.model_name != "unknown" + or "does not support model metadata retrieval" in caplog.text + )
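
A minimal usage sketch (not part of the patch) of the behavior this change enables, assuming the same environment variables used in the notebook and tests, and an endpoint that serves several models (for example, GitHub Models). The model name below is a placeholder; pick one that your endpoint actually exposes.

    import os

    from llama_index.core.llms import ChatMessage
    from llama_index.llms.azure_inference import AzureAICompletionsModel

    # When the endpoint hosts several models (GitHub Models, Azure AI Services),
    # `model_name` selects which one to call.
    llm = AzureAICompletionsModel(
        endpoint=os.environ["AZURE_INFERENCE_ENDPOINT"],
        credential=os.environ["AZURE_INFERENCE_CREDENTIAL"],
        model_name="mistral-large",  # placeholder; use a model your endpoint serves
    )

    # With this fix, an endpoint that does not implement `get_model_info` no longer
    # raises here: a warning is logged and "unknown" defaults are reported instead.
    print(llm.metadata.model_name)

    # Regular chat calls are unaffected.
    print(llm.chat([ChatMessage(role="user", content="Hello!")]))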