Skip to content

Commit

Permalink
add dimensions param to LiteLLMEmbedding, fix a bug that prevents…
Browse files Browse the repository at this point in the history
… reading vars from env (run-llama#15770)
  • Loading branch information
fcakyon authored Sep 1, 2024
1 parent 490f23d commit 5c6b8f9
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
from typing import List
from typing import Any, List, Optional
from litellm import embedding

from llama_index.core.bridge.pydantic import Field
from llama_index.core.embeddings import BaseEmbedding


def get_embeddings(api_key: str, api_base: str, model_name: str, input: List[str]):
if not api_key:
# If key is not provided, we assume the consumer has configured
# their LiteLLM proxy server with their API key.
api_key = "some key"
def get_embeddings(
api_key: str, api_base: str, model_name: str, input: List[str], **kwargs: Any
) -> List[List[float]]:
"""
Retrieve embeddings for a given list of input strings using the specified model.
Args:
api_key (str): The API key for authentication.
api_base (str): The base URL of the LiteLLM proxy server.
model_name (str): The name of the model to use for generating embeddings.
input (List[str]): A list of input strings for which embeddings are to be generated.
**kwargs (Any): Additional keyword arguments to be passed to the embedding function.
Returns:
List[List[float]]: A list of embeddings, where each embedding corresponds to an input string.
"""
response = embedding(
api_key=api_key,
api_base=api_base,
model=model_name,
input=input,
**kwargs,
)
return [result["embedding"] for result in response.data]


class LiteLLMEmbedding(BaseEmbedding):
model_name: str = Field(
default="unknown", description="The name of the embedding model."
)
api_key: str = Field(
default="unknown",
model_name: str = Field(description="The name of the embedding model.")
api_key: Optional[str] = Field(
default=None,
description="OpenAI key. If not provided, the proxy server must be configured with the key.",
)
api_base: str = Field(
default="unknown", description="The base URL of the LiteLLM proxy."
api_base: Optional[str] = Field(
default=None, description="The base URL of the LiteLLM proxy."
)
dimensions: Optional[int] = Field(
default=None,
description=(
"The number of dimensions the resulting output embeddings should have. "
"Only supported in text-embedding-3 and later models."
),
)

@classmethod
Expand All @@ -47,6 +63,7 @@ def _get_query_embedding(self, query: str) -> List[float]:
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
input=[query],
)
return embeddings[0]
Expand All @@ -56,6 +73,7 @@ def _get_text_embedding(self, text: str) -> List[float]:
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
input=[text],
)
return embeddings[0]
Expand All @@ -65,5 +83,6 @@ def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
api_key=self.api_key,
api_base=self.api_base,
model_name=self.model_name,
dimensions=self.dimensions,
input=texts,
)
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-litellm"
readme = "README.md"
version = "0.2.0"
version = "0.2.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 5c6b8f9

Please sign in to comment.