Skip to content

Commit

Permalink
fix: Pass env config to client ctor (#34)
Browse files Browse the repository at this point in the history
* fix: removed application and organization

* fix: tests
  • Loading branch information
asafgardin authored Jan 3, 2024
1 parent 6c9c0d0 commit 4d4ef71
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 41 deletions.
23 changes: 5 additions & 18 deletions ai21/ai21_http_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional, Dict, Any, BinaryIO

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyError
from ai21.http_client import HttpClient
from ai21.version import VERSION
Expand All @@ -16,25 +15,19 @@ def __init__(
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[int] = None,
num_retries: Optional[int] = None,
organization: Optional[str] = None,
application: Optional[str] = None,
via: Optional[str] = None,
http_client: Optional[HttpClient] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
self._env_config = env_config
self._api_key = api_key or self._env_config.api_key
self._api_key = api_key

if not self._api_key:
raise MissingApiKeyError()

self._api_host = api_host or self._env_config.api_host
self._api_version = api_version or self._env_config.api_version
self._api_host = api_host
self._api_version = api_version
self._headers = headers
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
self._num_retries = num_retries or self._env_config.num_retries
self._organization = organization
self._application = application
self._timeout_sec = timeout_sec
self._num_retries = num_retries
self._via = via

headers = self._build_headers(passed_headers=headers)
Expand Down Expand Up @@ -69,12 +62,6 @@ def _init_http_client(self, http_client: Optional[HttpClient], headers: Dict[str
def _build_user_agent(self) -> str:
user_agent = f"ai21 studio SDK {VERSION}"

if self._organization is not None:
user_agent = f"{user_agent} organization: {self._organization}"

if self._application is not None:
user_agent = f"{user_agent} application: {self._application}"

if self._via is not None:
user_agent = f"{user_agent} via: {self._via}"

Expand Down
11 changes: 7 additions & 4 deletions ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Any, Dict

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_http_client import AI21HTTPClient
from ai21.clients.studio.resources.studio_answer import StudioAnswer
from ai21.clients.studio.resources.studio_chat import StudioChat
Expand Down Expand Up @@ -35,14 +36,16 @@ def __init__(
num_retries: Optional[int] = None,
via: Optional[str] = None,
http_client: Optional[HttpClient] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
**kwargs,
):
self._http_client = AI21HTTPClient(
api_key=api_key,
api_host=api_host,
api_key=api_key or env_config.api_key,
api_host=api_host or env_config.api_host,
api_version=env_config.api_version,
headers=headers,
timeout_sec=timeout_sec,
num_retries=num_retries,
timeout_sec=timeout_sec or env_config.timeout_sec,
num_retries=num_retries or env_config.num_retries,
via=via,
http_client=http_client,
)
Expand Down
23 changes: 4 additions & 19 deletions tests/unittests/test_ai21_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,14 @@ class TestAI21StudioClient:
@pytest.mark.parametrize(
ids=[
"when_pass_only_via__should_include_via_in_user_agent",
"when_pass_only_application__should_include_application_in_user_agent",
"when_pass_organization__should_include_organization_in_user_agent",
"when_pass_all_user_agent_relevant_params__should_include_them_in_user_agent",
],
argnames=["via", "application", "organization", "expected_user_agent"],
argnames=["via", "expected_user_agent"],
argvalues=[
("langchain", None, None, f"ai21 studio SDK {VERSION} via: langchain"),
(None, "studio", None, f"ai21 studio SDK {VERSION} application: studio"),
(None, None, "ai21", f"ai21 studio SDK {VERSION} organization: ai21"),
(
"langchain",
"studio",
"ai21",
f"ai21 studio SDK {VERSION} organization: ai21 application: studio via: langchain",
),
("langchain", f"ai21 studio SDK {VERSION} via: langchain"),
],
)
def test__build_headers__user_agent(
self, via: Optional[str], application: Optional[str], organization: Optional[str], expected_user_agent: str
):
client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via, application=application, organization=organization)
def test__build_headers__user_agent(self, via: Optional[str], expected_user_agent: str):
client = AI21HTTPClient(api_key=_DUMMY_API_KEY, via=via)
assert client._http_client._headers["User-Agent"] == expected_user_agent

def test__build_headers__authorization(self):
Expand All @@ -67,12 +54,10 @@ def test__build_headers__when_pass_headers__should_append(self):

@pytest.mark.parametrize(
ids=[
"when_api_host_is_not_set__should_return_default",
"when_api_host_is_set__should_return_set_value",
],
argnames=["api_host", "expected_api_host"],
argvalues=[
(None, "https://api.ai21.com/studio/v1"),
("http://test_host", "http://test_host/studio/v1"),
],
)
Expand Down

0 comments on commit 4d4ef71

Please sign in to comment.