From bd49dbc29f92a0fa12f15e6f0fbb30f748883845 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Fri, 21 Jun 2024 18:28:02 +0300 Subject: [PATCH 1/2] fix: Added azure api version in the url --- ai21/clients/azure/ai21_azure_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ai21/clients/azure/ai21_azure_client.py b/ai21/clients/azure/ai21_azure_client.py index b2d59aee..a802afa4 100644 --- a/ai21/clients/azure/ai21_azure_client.py +++ b/ai21/clients/azure/ai21_azure_client.py @@ -8,6 +8,7 @@ from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat AzureADTokenProvider = Callable[[], str] +_DEFAULT_AZURE_VERSION = "v1" class BaseAzureClient(ABC): @@ -47,6 +48,7 @@ class AsyncAI21AzureClient(BaseAzureClient, AsyncAI21HTTPClient): def __init__( self, base_url: str, + api_version: str = _DEFAULT_AZURE_VERSION, api_key: Optional[str] = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, @@ -63,6 +65,9 @@ def __init__( headers = self._prepare_headers(headers=default_headers or {}) + if api_version: + base_url += f"/{api_version}" + super().__init__( api_key=api_key, base_url=base_url, @@ -81,6 +86,7 @@ class AI21AzureClient(BaseAzureClient, AI21HTTPClient): def __init__( self, base_url: str, + api_version: str = _DEFAULT_AZURE_VERSION, api_key: Optional[str] = None, azure_ad_token: str | None = None, azure_ad_token_provider: AzureADTokenProvider | None = None, @@ -97,6 +103,9 @@ def __init__( headers = self._prepare_headers(headers=default_headers or {}) + if api_version: + base_url += f"/{api_version}" + super().__init__( api_key=api_key, base_url=base_url, From 1299bb2faecdbb42980b4c268fad2294619d2688 Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Fri, 21 Jun 2024 18:34:03 +0300 Subject: [PATCH 2/2] fix: Added _create_base_url --- ai21/clients/azure/ai21_azure_client.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ai21/clients/azure/ai21_azure_client.py b/ai21/clients/azure/ai21_azure_client.py index a802afa4..39f45ee4 100644 --- a/ai21/clients/azure/ai21_azure_client.py +++ b/ai21/clients/azure/ai21_azure_client.py @@ -43,6 +43,12 @@ def _get_azure_ad_token(self) -> Optional[str]: return None + def _add_version_to_url(self, base_url: str, api_version: str) -> str: + if api_version: + return f"{base_url}/{api_version}" + + return f"{base_url}/{_DEFAULT_AZURE_VERSION}" + class AsyncAI21AzureClient(BaseAzureClient, AsyncAI21HTTPClient): def __init__( @@ -64,9 +70,7 @@ def __init__( raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") headers = self._prepare_headers(headers=default_headers or {}) - - if api_version: - base_url += f"/{api_version}" + base_url = self._add_version_to_url(base_url=base_url, api_version=api_version) super().__init__( api_key=api_key, @@ -102,9 +106,7 @@ def __init__( raise ValueError("Must provide either api_key or azure_ad_token_provider or azure_ad_token") headers = self._prepare_headers(headers=default_headers or {}) - - if api_version: - base_url += f"/{api_version}" + base_url = self._add_version_to_url(base_url=base_url, api_version=api_version) super().__init__( api_key=api_key,