Skip to content

Commit

Permalink
Bugfix upstage embedding when initializing the UpstageEmbedding class
Browse files: browse the repository at this point in the history
  • Loading branch information
freedom07 authored Sep 1, 2024
1 parent 5c6b8f9 commit cb6e9dd
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/docs/examples/embeddings/upstage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install llama-index-embeddings-upstage==0.1.0"
"%pip install llama-index-embeddings-upstage==0.2.1"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class UpstageEmbedding(OpenAIEmbedding):
default_factory=dict, description="Additional kwargs for the Upstage API."
)

api_key: str = Field(alias="upstage_api_key", description="The Upstage API key.")
api_key: str = Field(description="The Upstage API key.")
api_base: Optional[str] = Field(
default=DEFAULT_UPSTAGE_API_BASE, description="The base URL for Upstage API."
)
Expand Down Expand Up @@ -127,14 +127,14 @@ def __init__(
def class_name(cls) -> str:
return "UpstageEmbedding"

def _get_credential_kwargs(self) -> Dict[str, Any]:
def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]:
return {
"api_key": self.api_key,
"base_url": self.api_base,
"max_retries": self.max_retries,
"timeout": self.timeout,
"default_headers": self.default_headers,
"http_client": self._http_client,
"http_client": self._async_http_client if is_async else self._http_client,
}

def _get_query_embedding(self, query: str) -> List[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ license = "MIT"
name = "llama-index-embeddings-upstage"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.2.0"
version = "0.2.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import os

import pytest
from pytest_mock import MockerFixture

MOCK_EMBEDDING_DATA = [1.0, 2.0, 3.0]
UPSTAGE_TEST_API_KEY = "UPSTAGE_TEST_API_KEY"


@pytest.fixture()
def setup_environment(monkeypatch):
monkeypatch.setenv("UPSTAGE_API_KEY", UPSTAGE_TEST_API_KEY)


@pytest.fixture()
Expand All @@ -14,50 +23,101 @@ def upstage_embedding():
return UpstageEmbedding()


def test_upstage_embedding_query_embedding(upstage_embedding):
def test_upstage_embedding_query_embedding(
mocker: MockerFixture, setup_environment, upstage_embedding
):
query = "hello"
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._get_query_embedding"
)
mock_openai_client.return_value = MOCK_EMBEDDING_DATA

embedding = upstage_embedding.get_query_embedding(query)
assert isinstance(embedding, list)


async def test_upstage_embedding_async_query_embedding(upstage_embedding):
async def test_upstage_embedding_async_query_embedding(
mocker: MockerFixture, setup_environment, upstage_embedding
):
query = "hello"
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._aget_query_embedding"
)
mock_openai_client.return_value = MOCK_EMBEDDING_DATA

embedding = await upstage_embedding.aget_query_embedding(query)
assert isinstance(embedding, list)


def test_upstage_embedding_text_embedding(upstage_embedding):
def test_upstage_embedding_text_embedding(
mocker: MockerFixture, setup_environment, upstage_embedding
):
text = "hello"
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embedding"
)
mock_openai_client.return_value = MOCK_EMBEDDING_DATA

embedding = upstage_embedding.get_text_embedding(text)
assert isinstance(embedding, list)


async def test_upstage_embedding_async_text_embedding(upstage_embedding):
async def test_upstage_embedding_async_text_embedding(
mocker: MockerFixture, setup_environment, upstage_embedding
):
text = "hello"
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._aget_text_embedding"
)
mock_openai_client.return_value = MOCK_EMBEDDING_DATA

embedding = await upstage_embedding.aget_text_embedding(text)
assert isinstance(embedding, list)


def test_upstage_embedding_text_embeddings(upstage_embedding):
def test_upstage_embedding_text_embeddings(
mocker: MockerFixture, setup_environment, upstage_embedding
):
texts = ["hello", "world"]
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embeddings"
)
mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * len(texts)

embeddings = upstage_embedding.get_text_embedding_batch(texts)
assert isinstance(embeddings, list)
assert len(embeddings) == len(texts)
assert all(isinstance(embedding, list) for embedding in embeddings)


def test_upstage_embedding_text_embeddings_fail_large_batch():
def test_upstage_embedding_text_embeddings_fail_large_batch(
mocker: MockerFixture, setup_environment
):
large_batch_size = 2049
UpstageEmbedding = pytest.importorskip(
"llama_index.embeddings.upstage", reason="Cannot import UpstageEmbedding"
).UpstageEmbedding
texts = ["hello"] * 2049

mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embeddings"
)
mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * large_batch_size

texts = ["hello"] * large_batch_size
with pytest.raises(ValueError):
upstage_embedding = UpstageEmbedding(embed_batch_size=2049)
upstage_embedding.get_text_embedding_batch(texts)


async def test_upstage_embedding_async_text_embeddings(upstage_embedding):
async def test_upstage_embedding_async_text_embeddings(
mocker: MockerFixture, setup_environment, upstage_embedding
):
texts = ["hello", "world"]
mock_openai_client = mocker.patch(
"llama_index.embeddings.upstage.base.UpstageEmbedding._aget_text_embeddings"
)
mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * len(texts)

embeddings = await upstage_embedding.aget_text_embedding_batch(texts)
assert isinstance(embeddings, list)
assert len(embeddings) == len(texts)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
from llama_index.core.base.embeddings.base import BaseEmbedding

UPSTAGE_TEST_API_KEY = "upstage_test_key"


@pytest.fixture()
def upstage_embedding():
Expand All @@ -9,6 +11,11 @@ def upstage_embedding():
).UpstageEmbedding


@pytest.fixture()
def setup_environment(monkeypatch):
monkeypatch.setenv("UPSTAGE_API_KEY", UPSTAGE_TEST_API_KEY)


def test_upstage_embedding_class(upstage_embedding):
names_of_base_classes = [b.__name__ for b in upstage_embedding.__mro__]
assert BaseEmbedding.__name__ in names_of_base_classes
Expand All @@ -20,11 +27,15 @@ def test_upstage_embedding_fail_wrong_model(upstage_embedding):


def test_upstage_embedding_api_key_alias(upstage_embedding):
api_key = "test_key"
embedding1 = upstage_embedding(api_key=api_key)
embedding2 = upstage_embedding(upstage_api_key=api_key)
embedding3 = upstage_embedding(error_api_key=api_key)
embedding1 = upstage_embedding(api_key=UPSTAGE_TEST_API_KEY)
embedding2 = upstage_embedding(upstage_api_key=UPSTAGE_TEST_API_KEY)
embedding3 = upstage_embedding(error_api_key=UPSTAGE_TEST_API_KEY)

assert embedding1.api_key == api_key
assert embedding2.api_key == api_key
assert embedding1.api_key == UPSTAGE_TEST_API_KEY
assert embedding2.api_key == UPSTAGE_TEST_API_KEY
assert embedding3.api_key == ""


def test_upstage_embedding_api_key_with_env(setup_environment, upstage_embedding):
embedding = upstage_embedding()
assert embedding.api_key == UPSTAGE_TEST_API_KEY

0 comments on commit cb6e9dd

Please sign in to comment.