From cb6e9ddeb11132e9dcecb4823d70111328365f1a Mon Sep 17 00:00:00 2001 From: Yun Jegal Date: Mon, 2 Sep 2024 08:40:37 +0900 Subject: [PATCH] Bugfix upstage embedding when initializing the UpstageEmbedding class (#15767) --- docs/docs/examples/embeddings/upstage.ipynb | 2 +- .../llama_index/embeddings/upstage/base.py | 6 +- .../pyproject.toml | 2 +- .../integration_tests/test_integrations.py | 76 +++++++++++++++++-- .../unit_tests/test_embeddings_upstage.py | 23 ++++-- 5 files changed, 90 insertions(+), 19 deletions(-) diff --git a/docs/docs/examples/embeddings/upstage.ipynb b/docs/docs/examples/embeddings/upstage.ipynb index e1252bee5acc2..0c00df9f5f69e 100644 --- a/docs/docs/examples/embeddings/upstage.ipynb +++ b/docs/docs/examples/embeddings/upstage.ipynb @@ -28,7 +28,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install llama-index-embeddings-upstage==0.1.0" + "%pip install llama-index-embeddings-upstage==0.2.1" ] }, { diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/llama_index/embeddings/upstage/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/llama_index/embeddings/upstage/base.py index cc6a996f6199c..ef8787279f1fd 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/llama_index/embeddings/upstage/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/llama_index/embeddings/upstage/base.py @@ -51,7 +51,7 @@ class UpstageEmbedding(OpenAIEmbedding): default_factory=dict, description="Additional kwargs for the Upstage API." ) - api_key: str = Field(alias="upstage_api_key", description="The Upstage API key.") + api_key: str = Field(description="The Upstage API key.") api_base: Optional[str] = Field( default=DEFAULT_UPSTAGE_API_BASE, description="The base URL for Upstage API." ) @@ -127,14 +127,14 @@ def __init__( def class_name(cls) -> str: return "UpstageEmbedding" - def _get_credential_kwargs(self) -> Dict[str, Any]: + def _get_credential_kwargs(self, is_async: bool = False) -> Dict[str, Any]: return { "api_key": self.api_key, "base_url": self.api_base, "max_retries": self.max_retries, "timeout": self.timeout, "default_headers": self.default_headers, - "http_client": self._http_client, + "http_client": self._async_http_client if is_async else self._http_client, } def _get_query_embedding(self, query: str) -> List[float]: diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/pyproject.toml index 4aa9f3c03439d..a317c34f183d0 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/pyproject.toml @@ -30,7 +30,7 @@ license = "MIT" name = "llama-index-embeddings-upstage" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.2.0" +version = "0.2.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/integration_tests/test_integrations.py b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/integration_tests/test_integrations.py index ccf7e6cc3825d..4ad83129230bb 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/integration_tests/test_integrations.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/integration_tests/test_integrations.py @@ -1,6 +1,15 @@ import os import pytest +from pytest_mock import MockerFixture + +MOCK_EMBEDDING_DATA = [1.0, 2.0, 3.0] +UPSTAGE_TEST_API_KEY = "UPSTAGE_TEST_API_KEY" + + +@pytest.fixture() +def setup_environment(monkeypatch): + monkeypatch.setenv("UPSTAGE_API_KEY", UPSTAGE_TEST_API_KEY) @pytest.fixture() @@ -14,50 +23,101 @@ def upstage_embedding(): return UpstageEmbedding() -def test_upstage_embedding_query_embedding(upstage_embedding): +def test_upstage_embedding_query_embedding( + mocker: MockerFixture, setup_environment, upstage_embedding +): query = "hello" + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._get_query_embedding" + ) + mock_openai_client.return_value = MOCK_EMBEDDING_DATA + embedding = upstage_embedding.get_query_embedding(query) assert isinstance(embedding, list) -async def test_upstage_embedding_async_query_embedding(upstage_embedding): +async def test_upstage_embedding_async_query_embedding( + mocker: MockerFixture, setup_environment, upstage_embedding +): query = "hello" + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._aget_query_embedding" + ) + mock_openai_client.return_value = MOCK_EMBEDDING_DATA + embedding = await upstage_embedding.aget_query_embedding(query) assert isinstance(embedding, list) -def test_upstage_embedding_text_embedding(upstage_embedding): +def test_upstage_embedding_text_embedding( + mocker: MockerFixture, setup_environment, upstage_embedding +): text = "hello" + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embedding" + ) + mock_openai_client.return_value = MOCK_EMBEDDING_DATA + embedding = upstage_embedding.get_text_embedding(text) assert isinstance(embedding, list) -async def test_upstage_embedding_async_text_embedding(upstage_embedding): +async def test_upstage_embedding_async_text_embedding( + mocker: MockerFixture, setup_environment, upstage_embedding +): text = "hello" + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._aget_text_embedding" + ) + mock_openai_client.return_value = MOCK_EMBEDDING_DATA + embedding = await upstage_embedding.aget_text_embedding(text) assert isinstance(embedding, list) -def test_upstage_embedding_text_embeddings(upstage_embedding): +def test_upstage_embedding_text_embeddings( + mocker: MockerFixture, setup_environment, upstage_embedding +): texts = ["hello", "world"] + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embeddings" + ) + mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * len(texts) + embeddings = upstage_embedding.get_text_embedding_batch(texts) assert isinstance(embeddings, list) assert len(embeddings) == len(texts) assert all(isinstance(embedding, list) for embedding in embeddings) -def test_upstage_embedding_text_embeddings_fail_large_batch(): +def test_upstage_embedding_text_embeddings_fail_large_batch( + mocker: MockerFixture, setup_environment +): + large_batch_size = 2049 UpstageEmbedding = pytest.importorskip( "llama_index.embeddings.upstage", reason="Cannot import UpstageEmbedding" ).UpstageEmbedding - texts = ["hello"] * 2049 + + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._get_text_embeddings" + ) + mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * large_batch_size + + texts = ["hello"] * large_batch_size with pytest.raises(ValueError): upstage_embedding = UpstageEmbedding(embed_batch_size=2049) upstage_embedding.get_text_embedding_batch(texts) -async def test_upstage_embedding_async_text_embeddings(upstage_embedding): +async def test_upstage_embedding_async_text_embeddings( + mocker: MockerFixture, setup_environment, upstage_embedding +): texts = ["hello", "world"] + mock_openai_client = mocker.patch( + "llama_index.embeddings.upstage.base.UpstageEmbedding._aget_text_embeddings" + ) + mock_openai_client.return_value = [MOCK_EMBEDDING_DATA] * len(texts) + embeddings = await upstage_embedding.aget_text_embedding_batch(texts) assert isinstance(embeddings, list) assert len(embeddings) == len(texts) diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/unit_tests/test_embeddings_upstage.py b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/unit_tests/test_embeddings_upstage.py index 0e5e64f84bb31..f73d68cda1d8c 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/unit_tests/test_embeddings_upstage.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-upstage/tests/unit_tests/test_embeddings_upstage.py @@ -1,6 +1,8 @@ import pytest from llama_index.core.base.embeddings.base import BaseEmbedding +UPSTAGE_TEST_API_KEY = "upstage_test_key" + @pytest.fixture() def upstage_embedding(): @@ -9,6 +11,11 @@ def upstage_embedding(): ).UpstageEmbedding +@pytest.fixture() +def setup_environment(monkeypatch): + monkeypatch.setenv("UPSTAGE_API_KEY", UPSTAGE_TEST_API_KEY) + + def test_upstage_embedding_class(upstage_embedding): names_of_base_classes = [b.__name__ for b in upstage_embedding.__mro__] assert BaseEmbedding.__name__ in names_of_base_classes @@ -20,11 +27,15 @@ def test_upstage_embedding_fail_wrong_model(upstage_embedding): def test_upstage_embedding_api_key_alias(upstage_embedding): - api_key = "test_key" - embedding1 = upstage_embedding(api_key=api_key) - embedding2 = upstage_embedding(upstage_api_key=api_key) - embedding3 = upstage_embedding(error_api_key=api_key) + embedding1 = upstage_embedding(api_key=UPSTAGE_TEST_API_KEY) + embedding2 = upstage_embedding(upstage_api_key=UPSTAGE_TEST_API_KEY) + embedding3 = upstage_embedding(error_api_key=UPSTAGE_TEST_API_KEY) - assert embedding1.api_key == api_key - assert embedding2.api_key == api_key + assert embedding1.api_key == UPSTAGE_TEST_API_KEY + assert embedding2.api_key == UPSTAGE_TEST_API_KEY assert embedding3.api_key == "" + + +def test_upstage_embedding_api_key_with_env(setup_environment, upstage_embedding): + embedding = upstage_embedding() + assert embedding.api_key == UPSTAGE_TEST_API_KEY