diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-litellm/llama_index/embeddings/litellm/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-litellm/llama_index/embeddings/litellm/base.py
index a373a43071121..3b468dc0362b5 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-litellm/llama_index/embeddings/litellm/base.py
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-litellm/llama_index/embeddings/litellm/base.py
@@ -1,35 +1,51 @@
-from typing import List
+from typing import Any, List, Optional
 
 from litellm import embedding
 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.embeddings import BaseEmbedding
 
 
-def get_embeddings(api_key: str, api_base: str, model_name: str, input: List[str]):
-    if not api_key:
-        # If key is not provided, we assume the consumer has configured
-        # their LiteLLM proxy server with their API key.
-        api_key = "some key"
+def get_embeddings(
+    api_key: str, api_base: str, model_name: str, input: List[str], **kwargs: Any
+) -> List[List[float]]:
+    """
+    Retrieve embeddings for a given list of input strings using the specified model.
 
+    Args:
+        api_key (str): The API key for authentication.
+        api_base (str): The base URL of the LiteLLM proxy server.
+        model_name (str): The name of the model to use for generating embeddings.
+        input (List[str]): A list of input strings for which embeddings are to be generated.
+        **kwargs (Any): Additional keyword arguments to be passed to the embedding function.
+
+    Returns:
+        List[List[float]]: A list of embeddings, where each embedding corresponds to an input string.
+    """
     response = embedding(
         api_key=api_key,
         api_base=api_base,
         model=model_name,
         input=input,
+        **kwargs,
     )
     return [result["embedding"] for result in response.data]
 
 
 class LiteLLMEmbedding(BaseEmbedding):
-    model_name: str = Field(
-        default="unknown", description="The name of the embedding model."
-    )
-    api_key: str = Field(
-        default="unknown",
+    model_name: str = Field(description="The name of the embedding model.")
+    api_key: Optional[str] = Field(
+        default=None,
         description="OpenAI key. If not provided, the proxy server must be configured with the key.",
     )
-    api_base: str = Field(
-        default="unknown", description="The base URL of the LiteLLM proxy."
+    api_base: Optional[str] = Field(
+        default=None, description="The base URL of the LiteLLM proxy."
+    )
+    dimensions: Optional[int] = Field(
+        default=None,
+        description=(
+            "The number of dimensions the resulting output embeddings should have. "
+            "Only supported in text-embedding-3 and later models."
+        ),
     )
 
     @classmethod
@@ -47,6 +63,7 @@ def _get_query_embedding(self, query: str) -> List[float]:
             api_key=self.api_key,
             api_base=self.api_base,
             model_name=self.model_name,
+            dimensions=self.dimensions,
             input=[query],
         )
         return embeddings[0]
@@ -56,6 +73,7 @@ def _get_text_embedding(self, text: str) -> List[float]:
             api_key=self.api_key,
             api_base=self.api_base,
             model_name=self.model_name,
+            dimensions=self.dimensions,
             input=[text],
         )
         return embeddings[0]
@@ -65,5 +83,6 @@ def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
             api_key=self.api_key,
             api_base=self.api_base,
             model_name=self.model_name,
+            dimensions=self.dimensions,
             input=texts,
         )
diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-litellm/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-litellm/pyproject.toml
index 07213990820a3..01099fd174f33 100644
--- a/llama-index-integrations/embeddings/llama-index-embeddings-litellm/pyproject.toml
+++ b/llama-index-integrations/embeddings/llama-index-embeddings-litellm/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-embeddings-litellm"
 readme = "README.md"
-version = "0.2.0"
+version = "0.2.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
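For context, a minimal usage sketch of the new `dimensions` field. The model name, API key, proxy URL, and dimension value below are illustrative placeholders, not part of this change; per the new `Field` description, `dimensions` is only honored by models that support shortened embeddings (e.g. the text-embedding-3 family).

```python
from llama_index.embeddings.litellm import LiteLLMEmbedding

# Illustrative values only. `dimensions` is forwarded through
# get_embeddings(**kwargs) into litellm.embedding(), so it only takes
# effect on models that support shortened embeddings.
embed_model = LiteLLMEmbedding(
    model_name="text-embedding-3-small",
    api_key="sk-...",  # may be omitted if the LiteLLM proxy holds the key
    api_base="http://localhost:4000",  # hypothetical LiteLLM proxy URL
    dimensions=256,  # request 256-dimensional output embeddings
)

vector = embed_model.get_text_embedding("hello world")
print(len(vector))  # expected: 256 for models that honor `dimensions`
```

Leaving `dimensions` unset preserves the previous behavior, since the default is `None` and the underlying call then uses the model's native embedding size.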