Skip to content

Commit

Permalink
feat: Moved body creationto function
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Dec 14, 2023
1 parent 3a48e8c commit ca9db03
Show file tree
Hide file tree
Showing 30 changed files with 273 additions and 125 deletions.
1 change: 1 addition & 0 deletions ai21/clients/bedrock/resources/bedrock_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def create(
"frequencyPenalty": frequency_penalty or {},
"presencePenalty": presence_penalty or {},
"countPenalty": count_penalty or {},
**kwargs,
}
raw_response = self._invoke(model_id=model_id, body=body)

Expand Down
9 changes: 2 additions & 7 deletions ai21/clients/sagemaker/resources/sagemaker_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ def create(
mode: Optional[str] = None,
**kwargs,
) -> AnswerResponse:
params = {
"context": context,
"question": question,
"answerLength": answer_length,
"mode": mode,
}
response = self._invoke(params)
body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode, **kwargs)
response = self._invoke(body)

return self._json_to_response(response)
1 change: 1 addition & 0 deletions ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def create(
"frequencyPenalty": frequency_penalty,
"presencePenalty": presence_penalty,
"countPenalty": count_penalty,
**kwargs,
}
raw_response = self._invoke(body)

Expand Down
4 changes: 1 addition & 3 deletions ai21/clients/sagemaker/resources/sagemaker_gec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

class SageMakerGEC(SageMakerResource, GEC):
def create(self, text: str, **kwargs) -> GECResponse:
body = {
"text": text,
}
body = self._create_body(text=text, **kwargs)

response = self._invoke(body)

Expand Down
13 changes: 7 additions & 6 deletions ai21/clients/sagemaker/resources/sagemaker_paraphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ def create(
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
body = {
"text": text,
"style": style,
"startIndex": start_index,
"endIndex": end_index,
}
body = self._create_body(
text=text,
style=style,
start_index=start_index,
end_index=end_index,
**kwargs,
)
response = self._invoke(body=body)

return self._json_to_response(response)
15 changes: 8 additions & 7 deletions ai21/clients/sagemaker/resources/sagemaker_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ def create(
summary_method: Optional[str] = None,
**kwargs,
) -> SummarizeResponse:
params = {
"source": source,
"sourceType": source_type,
"focus": focus,
"summaryMethod": summary_method,
}
body = self._create_body(
source=source,
source_type=source_type,
focus=focus,
summary_method=summary_method,
**kwargs,
)

response = self._invoke(params)
response = self._invoke(body)

return self._json_to_response(response)
9 changes: 2 additions & 7 deletions ai21/clients/studio/resources/studio_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@ def create(
) -> AnswerResponse:
url = f"{self._client.get_base_url()}/{self._MODULE_NAME}"

params = {
"context": context,
"question": question,
"answerLength": answer_length,
"mode": mode,
}
body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode, **kwargs)

response = self._post(url=url, body=params)
response = self._post(url=url, body=body)

return self._json_to_response(response)
34 changes: 17 additions & 17 deletions ai21/clients/studio/resources/studio_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ def create(
count_penalty: Optional[Dict[str, Any]] = None,
**kwargs,
) -> ChatResponse:
# Make a chat request to the AI21 API. Returns the response either as a string or a AI21Chat object.
params = {
"model": model,
"system": system,
"messages": messages,
"temperature": temperature,
"maxTokens": max_tokens,
"minTokens": min_tokens,
"numResults": num_results,
"topP": top_p,
"topKReturn": top_k_returns,
"stopSequences": stop_sequences,
"frequencyPenalty": frequency_penalty,
"presencePenalty": presence_penalty,
"countPenalty": count_penalty,
}
body = self._create_body(
model=model,
messages=messages,
system=system,
num_results=num_results,
temperature=temperature,
max_tokens=max_tokens,
min_tokens=min_tokens,
top_p=top_p,
top_k_returns=top_k_returns,
stop_sequences=stop_sequences,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
count_penalty=count_penalty,
**kwargs,
)
url = f"{self._client.get_base_url()}/{model}/{self._module_name}"
response = self._post(url=url, body=params)
response = self._post(url=url, body=body)
return self._json_to_response(response)
35 changes: 18 additions & 17 deletions ai21/clients/studio/resources/studio_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,22 @@ def create(
url = f"{url}/{custom_model}"

url = f"{url}/{self._module_name}"
body = {
"model": model,
"customModel": custom_model,
"experimentalModel": experimental_mode,
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": frequency_penalty,
"presencePenalty": presence_penalty,
"countPenalty": count_penalty,
"epoch": epoch,
}
body = self._create_body(
model=model,
prompt=prompt,
max_tokens=max_tokens,
num_results=num_results,
min_tokens=min_tokens,
temperature=temperature,
top_p=top_p,
top_k_return=top_k_return,
custom_model=custom_model,
experimental_mode=experimental_mode,
stop_sequences=stop_sequences,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
count_penalty=count_penalty,
epoch=epoch,
**kwargs,
)
return self._json_to_response(self._post(url=url, body=body))
15 changes: 8 additions & 7 deletions ai21/clients/studio/resources/studio_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ def create(
**kwargs,
) -> None:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = {
"dataset_id": dataset_id,
"model_name": model_name,
"model_type": model_type,
"learning_rate": learning_rate,
"num_epochs": num_epochs,
}
body = self._create_body(
dataset_id=dataset_id,
model_name=model_name,
model_type=model_type,
learning_rate=learning_rate,
num_epochs=num_epochs,
**kwargs,
)
self._post(url=url, body=body)

def list(self) -> List[CustomModelResponse]:
Expand Down
16 changes: 8 additions & 8 deletions ai21/clients/studio/resources/studio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def upload(
**kwargs,
):
files = {"dataset_file": open(file_path, "rb")}
body = {
"dataset_name": dataset_name,
"selected_columns": selected_columns,
"approve_whitespace_correction": approve_whitespace_correction,
"delete_long_rows": delete_long_rows,
"split_ratio": split_ratio,
}

body = self._create_body(
dataset_name=dataset_name,
selected_columns=selected_columns,
approve_whitespace_correction=approve_whitespace_correction,
delete_long_rows=delete_long_rows,
split_ratio=split_ratio,
**kwargs,
)
return self._post(
url=self._base_url(),
body=body,
Expand Down
3 changes: 2 additions & 1 deletion ai21/clients/studio/resources/studio_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class StudioEmbed(StudioResource, Embed):
def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body={"texts": texts, "type": type})
body = self._create_body(texts=texts, type=type, **kwargs)
response = self._post(url=url, body=body)

return self._json_to_response(response)
4 changes: 1 addition & 3 deletions ai21/clients/studio/resources/studio_gec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@

class StudioGEC(StudioResource, GEC):
def create(self, text: str, **kwargs) -> GECResponse:
body = {
"text": text,
}
body = self._create_body(text=text, **kwargs)
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=body)

Expand Down
4 changes: 2 additions & 2 deletions ai21/clients/studio/resources/studio_improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse:
raise EmptyMandatoryListException("types")

url = f"{self._client.get_base_url()}/{self._module_name}"

response = self._post(url=url, body={"text": text, "types": types})
body = self._create_body(text=text, types=types, **kwargs)
response = self._post(url=url, body=body)

return self._json_to_response(response)
13 changes: 7 additions & 6 deletions ai21/clients/studio/resources/studio_paraphrase.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ def create(
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
body = {
"text": text,
"style": style,
"startIndex": start_index,
"endIndex": end_index,
}
body = self._create_body(
text=text,
style=style,
start_index=start_index,
end_index=end_index,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=body)

Expand Down
6 changes: 1 addition & 5 deletions ai21/clients/studio/resources/studio_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@

class StudioSegmentation(StudioResource, Segmentation):
def create(self, source: str, source_type: str):
body = {
"source": source,
"sourceType": source_type,
}

body = self._create_body(source=source, source_type=source_type)
url = f"{self._client.get_base_url()}/{self._module_name}"
raw_response = self._post(url=url, body=body)

Expand Down
15 changes: 8 additions & 7 deletions ai21/clients/studio/resources/studio_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ def create(
**kwargs,
) -> SummarizeResponse:
# Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object.
params = {
"source": source,
"sourceType": source_type,
"focus": focus,
"summaryMethod": summary_method,
}
body = self._create_body(
source=source,
source_type=source_type,
focus=focus,
summary_method=summary_method,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=params)
response = self._post(url=url, body=body)

return self._json_to_response(response)
11 changes: 6 additions & 5 deletions ai21/clients/studio/resources/studio_summarize_by_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ class StudioSummarizeBySegment(StudioResource, SummarizeBySegment):
def create(
self, source: str, source_type: str, *, focus: Optional[str] = None, **kwargs
) -> SummarizeBySegmentResponse:
body = {
"source": source,
"sourceType": source_type,
"focus": focus,
}
body = self._create_body(
source=source,
source_type=source_type,
focus=focus,
**kwargs,
)
url = f"{self._client.get_base_url()}/{self._module_name}"
response = self._post(url=url, body=body)
return self._json_to_response(response)
10 changes: 10 additions & 0 deletions ai21/resources/bases/answer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@ def create(

def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse:
return AnswerResponse.model_validate(json)

def _create_body(
self,
context: str,
question: str,
answer_length: Optional[str],
mode: Optional[str],
**kwargs,
) -> Dict[str, Any]:
return {"context": context, "question": question, "answerLength": answer_length, "mode": mode, **kwargs}
34 changes: 34 additions & 0 deletions ai21/resources/bases/chat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,37 @@ def create(

def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse:
return ChatResponse.model_validate(json)

def _create_body(
self,
model: str,
messages: List[Message],
system: str,
num_results: Optional[int] = 1,
temperature: Optional[float] = 0.7,
max_tokens: Optional[int] = 300,
min_tokens: Optional[int] = 0,
top_p: Optional[float] = 1.0,
top_k_returns: Optional[int] = 0,
stop_sequences: Optional[List[str]] = None,
frequency_penalty: Optional[Dict[str, Any]] = None,
presence_penalty: Optional[Dict[str, Any]] = None,
count_penalty: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"system": system,
"messages": messages,
"temperature": temperature,
"maxTokens": max_tokens,
"minTokens": min_tokens,
"numResults": num_results,
"topP": top_p,
"topKReturn": top_k_returns,
"stopSequences": stop_sequences,
"frequencyPenalty": frequency_penalty,
"presencePenalty": presence_penalty,
"countPenalty": count_penalty,
**kwargs,
}
Loading

0 comments on commit ca9db03

Please sign in to comment.