Skip to content

Commit

Permalink
fix: format
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Dec 12, 2023
1 parent 908baf4 commit aef1314
Show file tree
Hide file tree
Showing 24 changed files with 58 additions and 140 deletions.
16 changes: 4 additions & 12 deletions ai21/ai21_studio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@ def __init__(

headers = self._build_headers(passed_headers=headers)

self.http_client = HttpClient(
timeout_sec=timeout_sec, num_retries=num_retries, headers=headers
)
self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers)

def _build_headers(
self, passed_headers: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"User-Agent": self._build_user_agent(),
Expand Down Expand Up @@ -71,12 +67,8 @@ def _build_user_agent(self) -> str:

return user_agent

def execute_http_request(
self, method: str, url: str, params: Optional[Dict] = None, files=None
):
return self.http_client.execute_http_request(
method=method, url=url, params=params, files=files
)
def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None):
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
12 changes: 3 additions & 9 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def __init__(
env_config: _AI21EnvConfig = AI21EnvConfig,
):
self._session = (
session.client(RUNTIME_NAME)
if session
else boto3.client(RUNTIME_NAME, region_name=env_config.aws_region)
session.client(RUNTIME_NAME) if session else boto3.client(RUNTIME_NAME, region_name=env_config.aws_region)
)
self.completion = resources.BedrockCompletion(self)

Expand All @@ -55,9 +53,7 @@ def invoke_model(self, model_id: str, input_json: str) -> Dict[str, Any]:
def _handle_client_error(self, client_exception: ClientError) -> None:
error_response = client_exception.response
error_message = error_response.get("Error", {}).get("Message", "")
status_code = error_response.get("ResponseMetadata", {}).get(
"HTTPStatusCode", None
)
status_code = error_response.get("ResponseMetadata", {}).get("HTTPStatusCode", None)
# As written in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_InvokeModel.html

if status_code == 403:
Expand All @@ -71,9 +67,7 @@ def _handle_client_error(self, client_exception: ClientError) -> None:

if status_code == 424:
error_message_template = re.compile(_ERROR_MSG_TEMPLATE)
model_status_code = int(
error_message_template.search(error_message).group(1)
)
model_status_code = int(error_message_template.search(error_message).group(1))
model_error_message = error_message_template.search(error_message).group(2)
handle_non_success_response(model_status_code, model_error_message)

Expand Down
18 changes: 4 additions & 14 deletions ai21/clients/sagemaker/ai21_sagemaker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,7 @@ def __init__(
)
self._env_config = env_config
self._session = (
session
if session
else boto3.client(
"sagemaker-runtime", region_name=self._env_config.aws_region
)
session if session else boto3.client("sagemaker-runtime", region_name=self._env_config.aws_region)
)
self._region = region or self._env_config.aws_region
self._endpoint_name = endpoint_name
Expand All @@ -81,25 +77,19 @@ def invoke_endpoint(
except ClientError as sm_client_error:
self._handle_client_error(sm_client_error)
except Exception as exception:
log_error(
f"Calling {self._endpoint_name} failed with Exception: {exception}"
)
log_error(f"Calling {self._endpoint_name} failed with Exception: {exception}")
raise exception

def _handle_client_error(self, client_exception: ClientError):
error_response = client_exception.response
error_message = error_response.get("Error", {}).get("Message", "")
status_code = error_response.get("ResponseMetadata", {}).get(
"HTTPStatusCode", None
)
status_code = error_response.get("ResponseMetadata", {}).get("HTTPStatusCode", None)
# According to https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html#API_runtime_InvokeEndpoint_Errors
if status_code == 400:
raise BadRequest(details=error_message)
if status_code == 424:
error_message_template = re.compile(_ERROR_MSG_TEMPLATE)
model_status_code = int(
error_message_template.search(error_message).group(1)
)
model_status_code = int(error_message_template.search(error_message).group(1))
model_error_message = error_message_template.search(error_message).group(2)
handle_non_success_response(model_status_code, model_error_message)
if status_code == 429 or status_code == 503:
Expand Down
4 changes: 1 addition & 3 deletions ai21/clients/studio/resources/studio_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@


class StudioEmbed(StudioResource, Embed):
def create(
self, texts: List[str], type: Optional[str] = None, **kwargs
) -> EmbedResponse:
def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._invoke(url=url, body={"texts": texts, "type": type})

Expand Down
28 changes: 7 additions & 21 deletions ai21/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,9 @@ def requests_retry_session(session, retries=0):


class HttpClient:
def __init__(
self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None
):
self.timeout_sec = (
timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC
)
self.num_retries = (
num_retries if num_retries is not None else DEFAULT_NUM_RETRIES
)
def __init__(self, timeout_sec: int = None, num_retries: int = None, headers: Dict = None):
self.timeout_sec = timeout_sec if timeout_sec is not None else DEFAULT_TIMEOUT_SEC
self.num_retries = num_retries if num_retries is not None else DEFAULT_NUM_RETRIES
self.headers = headers if headers is not None else {}
self.apply_retry_policy = self.num_retries > 0

Expand Down Expand Up @@ -112,27 +106,19 @@ def execute_http_request(
auth=auth,
)
else:
response = session.request(
method, url, headers=headers, data=data, timeout=timeout, auth=auth
)
response = session.request(method, url, headers=headers, data=data, timeout=timeout, auth=auth)
except ConnectionError as connection_error:
log_error(
f"Calling {method} {url} failed with ConnectionError: {connection_error}"
)
log_error(f"Calling {method} {url} failed with ConnectionError: {connection_error}")
raise connection_error
except RetryError as retry_error:
log_error(
f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}"
)
log_error(f"Calling {method} {url} failed with RetryError after {self.num_retries} attempts: {retry_error}")
raise retry_error
except Exception as exception:
log_error(f"Calling {method} {url} failed with Exception: {exception}")
raise exception

if response.status_code != 200:
log_error(
f"Calling {method} {url} failed with a non-200 response code: {response.status_code}"
)
log_error(f"Calling {method} {url} failed with a non-200 response code: {response.status_code}")
handle_non_success_response(response.status_code, response.text)

return response.json()
8 changes: 2 additions & 6 deletions ai21/modules/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

class SageMaker:
@classmethod
def get_model_package_arn(
cls, model_name: str, region: str, version: str = LATEST_VERSION_STR
) -> str:
def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str:
_assert_model_package_exists(model_name=model_name, region=region)

client = AI21StudioClient(auth_required=False)
Expand All @@ -35,9 +33,7 @@ def get_model_package_arn(
arn = response["arn"]

if not arn:
raise ModelPackageDoesntExistException(
model_name=model_name, region=region, version=version
)
raise ModelPackageDoesntExistException(model_name=model_name, region=region, version=version)

return arn

Expand Down
4 changes: 1 addition & 3 deletions ai21/resources/bases/embed_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ class Embed(ABC):
_module_name = "embed"

@abstractmethod
def create(
self, texts: List[str], *, type: Optional[str] = None, **kwargs
) -> EmbedResponse:
def create(self, texts: List[str], *, type: Optional[str] = None, **kwargs) -> EmbedResponse:
pass

def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse:
Expand Down
8 changes: 2 additions & 6 deletions ai21/resources/studio_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,10 @@ def _post(
)

def _get(self, url: str, params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
return self._client.execute_http_request(
method="GET", url=url, params=params or {}
)
return self._client.execute_http_request(method="GET", url=url, params=params or {})

def _put(self, url: str, body: Dict[str, Any] = None) -> Dict[str, Any]:
return self._client.execute_http_request(
method="PUT", url=url, params=body or {}
)
return self._client.execute_http_request(method="PUT", url=url, params=body or {})

def _delete(self, url: str) -> Dict[str, Any]:
return self._client.execute_http_request(
Expand Down
4 changes: 1 addition & 3 deletions examples/bedrock/completion_bedrock_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

prompt = "The following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Male.\n\nConversation:\nUser: Hi, had a question\nMax: Hi there, happy to help!\nUser: Is there no way to return a product? I got your blue T-Shirt size small but it doesn't fit.\nMax: I'm sorry to hear that. Unfortunately we don't have a return policy. \nUser: That's a shame. \nMax: Is there anything else i can do for you?\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you'll have the \"Blue & White\" t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: [email protected]\nMax: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\nMax: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\nMax: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\nMax: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following are important points about the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, I have a question for you"

response = AI21BedrockClient().completion.create(
prompt=prompt, max_tokens=1000, model_id=ai21.BedrockModelID.J2_MID_V1
)
response = AI21BedrockClient().completion.create(prompt=prompt, max_tokens=1000, model_id=ai21.BedrockModelID.J2_MID_V1)

print(response.completions[0].data.text)
print(response.prompt.tokens[0].text_range.start)
4 changes: 1 addition & 3 deletions examples/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
)

client = AI21Client()
response = client.completion.create(
prompt=prompt, max_tokens=2, model="j2-light", temperature=0
)
response = client.completion.create(prompt=prompt, max_tokens=2, model="j2-light", temperature=0)

print(response)
print(response.completions[0].data.text)
Expand Down
4 changes: 1 addition & 3 deletions examples/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@

client = AI21Client()
response = client.embed.create(
texts=[
"Holland is a geographical region[2] and former province on the western coast of the Netherlands."
],
texts=["Holland is a geographical region[2] and former province on the western coast of the Netherlands."],
type="segment",
)
print("embed: ", response.results[0].embedding)
32 changes: 17 additions & 15 deletions examples/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,23 @@


def create_file(path: str, name: str):
content = ("Sloths are a Neotropical group of xenarthran mammals constituting the suborder Folivora,"
" including the extant arboreal tree sloths and extinct terrestrial ground sloths."
" Noted for their slowness of movement, tree sloths spend most of their lives hanging "
"upside down in the trees of the tropical rainforests of South America and Central America."
" Sloths are considered to be most closely related to anteaters, together making up the"
" xenarthran order Pilosa. There are six extant sloth species in two genera – Bradypus"
" (three–toed sloths) and Choloepus (two–toed sloths). Despite this traditional naming,"
" all sloths have three toes on each rear limb-- although two-toed sloths have only two digits"
" on each forelimb.[3] The two groups of sloths are from different, distantly related families,"
" and are thought to have evolved their morphology via parallel evolution from terrestrial ancestors."
" Besides the extant species, many species of ground sloths ranging up to the size of elephants"
" (like Megatherium) inhabited both North and South America during the Pleistocene Epoch. However,"
" they became extinct during the Quaternary extinction event around 12,000 years ago,"
" along with most large bodied animals in the New World. The extinction correlates in time "
"with the arrival of humans, but climate change has also been suggested to have contributed. Members of an endemic radiation of Caribbean sloths also formerly lived in the Greater Antilles but became extinct after humans settled the archipelago in the mid-Holocene, around 6,000 years ago.")
content = (
"Sloths are a Neotropical group of xenarthran mammals constituting the suborder Folivora,"
" including the extant arboreal tree sloths and extinct terrestrial ground sloths."
" Noted for their slowness of movement, tree sloths spend most of their lives hanging "
"upside down in the trees of the tropical rainforests of South America and Central America."
" Sloths are considered to be most closely related to anteaters, together making up the"
" xenarthran order Pilosa. There are six extant sloth species in two genera – Bradypus"
" (three–toed sloths) and Choloepus (two–toed sloths). Despite this traditional naming,"
" all sloths have three toes on each rear limb-- although two-toed sloths have only two digits"
" on each forelimb.[3] The two groups of sloths are from different, distantly related families,"
" and are thought to have evolved their morphology via parallel evolution from terrestrial ancestors."
" Besides the extant species, many species of ground sloths ranging up to the size of elephants"
" (like Megatherium) inhabited both North and South America during the Pleistocene Epoch. However,"
" they became extinct during the Quaternary extinction event around 12,000 years ago,"
" along with most large bodied animals in the New World. The extinction correlates in time "
"with the arrival of humans, but climate change has also been suggested to have contributed. Members of an endemic radiation of Caribbean sloths also formerly lived in the Greater Antilles but became extinct after humans settled the archipelago in the mid-Holocene, around 6,000 years ago."
)
f = open(os.path.join(path, name), "w")
f.write(content)
f.close()
Expand Down
4 changes: 1 addition & 3 deletions examples/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ def validate_file_deleted():
print(uploaded_file.labels)
print(uploaded_file.public_url)

client.library.files.update(
file_id, publicUrl="www.example-updated.com", labels=["label3", "label4"]
)
client.library.files.update(file_id, publicUrl="www.example-updated.com", labels=["label3", "label4"])
updated_file = client.library.files.get(file_id)
print(updated_file.name)
print(updated_file.public_url)
Expand Down
4 changes: 1 addition & 3 deletions examples/paraphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@


client = AI21Client()
response = client.paraphrase.create(
text="The cat (Felis catus) is a domestic species of small carnivorous mammal"
)
response = client.paraphrase.create(text="The cat (Felis catus) is a domestic species of small carnivorous mammal")

print(response.suggestions[0].text)
print(response.suggestions[1].text)
Expand Down
4 changes: 1 addition & 3 deletions examples/sagemaker/answer_sm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient

client = AI21SageMakerClient(
endpoint_name="j2-quantization-mid-reach-dev-cve-version-12-202313"
)
client = AI21SageMakerClient(endpoint_name="j2-quantization-mid-reach-dev-cve-version-12-202313")

response = client.answer.create(
context="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2] From the "
Expand Down
Loading

0 comments on commit aef1314

Please sign in to comment.