From ca9db03634576abdc8e358af60f941201a9ccbe8 Mon Sep 17 00:00:00 2001 From: Asaf Gardin Date: Thu, 14 Dec 2023 15:59:57 +0200 Subject: [PATCH] feat: Moved body creationto function --- .../bedrock/resources/bedrock_completion.py | 1 + .../sagemaker/resources/sagemaker_answer.py | 9 +--- .../resources/sagemaker_completion.py | 1 + .../sagemaker/resources/sagemaker_gec.py | 4 +- .../resources/sagemaker_paraphrase.py | 13 ++--- .../resources/sagemaker_summarize.py | 15 +++--- .../clients/studio/resources/studio_answer.py | 9 +--- ai21/clients/studio/resources/studio_chat.py | 34 ++++++------ .../studio/resources/studio_completion.py | 35 +++++++------ .../studio/resources/studio_custom_model.py | 15 +++--- .../studio/resources/studio_dataset.py | 16 +++--- ai21/clients/studio/resources/studio_embed.py | 3 +- ai21/clients/studio/resources/studio_gec.py | 4 +- .../studio/resources/studio_improvements.py | 4 +- .../studio/resources/studio_paraphrase.py | 13 ++--- .../studio/resources/studio_segmentation.py | 6 +-- .../studio/resources/studio_summarize.py | 15 +++--- .../resources/studio_summarize_by_segment.py | 11 ++-- ai21/resources/bases/answer_base.py | 10 ++++ ai21/resources/bases/chat_base.py | 34 ++++++++++++ ai21/resources/bases/completion_base.py | 52 +++++++++++++------ ai21/resources/bases/custom_model_base.py | 18 +++++++ ai21/resources/bases/dataset_base.py | 18 +++++++ ai21/resources/bases/embed_base.py | 3 ++ ai21/resources/bases/gec_base.py | 3 ++ ai21/resources/bases/improvements_base.py | 3 ++ ai21/resources/bases/paraphrase_base.py | 16 ++++++ ai21/resources/bases/segmentation_base.py | 3 ++ ai21/resources/bases/summarize_base.py | 16 ++++++ .../bases/summarize_by_segment_base.py | 14 +++++ 30 files changed, 273 insertions(+), 125 deletions(-) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index b7d47e67..a61753a8 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -34,6 +34,7 @@ def create( "frequencyPenalty": frequency_penalty or {}, "presencePenalty": presence_penalty or {}, "countPenalty": count_penalty or {}, + **kwargs, } raw_response = self._invoke(model_id=model_id, body=body) diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index 093f668c..2ccae63e 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -15,12 +15,7 @@ def create( mode: Optional[str] = None, **kwargs, ) -> AnswerResponse: - params = { - "context": context, - "question": question, - "answerLength": answer_length, - "mode": mode, - } - response = self._invoke(params) + body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode, **kwargs) + response = self._invoke(body) return self._json_to_response(response) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 409d2330..68ec8be9 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -33,6 +33,7 @@ def create( "frequencyPenalty": frequency_penalty, "presencePenalty": presence_penalty, "countPenalty": count_penalty, + **kwargs, } raw_response = self._invoke(body) diff --git a/ai21/clients/sagemaker/resources/sagemaker_gec.py b/ai21/clients/sagemaker/resources/sagemaker_gec.py index 3272d4ca..35b4dc58 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_gec.py +++ b/ai21/clients/sagemaker/resources/sagemaker_gec.py @@ -5,9 +5,7 @@ class SageMakerGEC(SageMakerResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: - body = { - "text": text, - } + body = self._create_body(text=text, **kwargs) response = self._invoke(body) diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py index ab7f8ec3..cec154ab 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py @@ -15,12 +15,13 @@ def create( end_index: Optional[int] = None, **kwargs, ) -> ParaphraseResponse: - body = { - "text": text, - "style": style, - "startIndex": start_index, - "endIndex": end_index, - } + body = self._create_body( + text=text, + style=style, + start_index=start_index, + end_index=end_index, + **kwargs, + ) response = self._invoke(body=body) return self._json_to_response(response) diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py index ea252107..555e1e0a 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ b/ai21/clients/sagemaker/resources/sagemaker_summarize.py @@ -17,13 +17,14 @@ def create( summary_method: Optional[str] = None, **kwargs, ) -> SummarizeResponse: - params = { - "source": source, - "sourceType": source_type, - "focus": focus, - "summaryMethod": summary_method, - } + body = self._create_body( + source=source, + source_type=source_type, + focus=focus, + summary_method=summary_method, + **kwargs, + ) - response = self._invoke(params) + response = self._invoke(body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 060292a0..6ba7e18b 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -17,13 +17,8 @@ def create( ) -> AnswerResponse: url = f"{self._client.get_base_url()}/{self._MODULE_NAME}" - params = { - "context": context, - "question": question, - "answerLength": answer_length, - "mode": mode, - } + body = self._create_body(context=context, question=question, answer_length=answer_length, mode=mode, **kwargs) - response = self._post(url=url, body=params) + response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index 4a5ef6f5..3ec47955 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -24,22 +24,22 @@ def create( count_penalty: Optional[Dict[str, Any]] = None, **kwargs, ) -> ChatResponse: - # Make a chat request to the AI21 API. Returns the response either as a string or a AI21Chat object. - params = { - "model": model, - "system": system, - "messages": messages, - "temperature": temperature, - "maxTokens": max_tokens, - "minTokens": min_tokens, - "numResults": num_results, - "topP": top_p, - "topKReturn": top_k_returns, - "stopSequences": stop_sequences, - "frequencyPenalty": frequency_penalty, - "presencePenalty": presence_penalty, - "countPenalty": count_penalty, - } + body = self._create_body( + model=model, + messages=messages, + system=system, + num_results=num_results, + temperature=temperature, + max_tokens=max_tokens, + min_tokens=min_tokens, + top_p=top_p, + top_k_returns=top_k_returns, + stop_sequences=stop_sequences, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + count_penalty=count_penalty, + **kwargs, + ) url = f"{self._client.get_base_url()}/{model}/{self._module_name}" - response = self._post(url=url, body=params) + response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index 8f17e86f..eb61d8d6 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -37,21 +37,22 @@ def create( url = f"{url}/{custom_model}" url = f"{url}/{self._module_name}" - body = { - "model": model, - "customModel": custom_model, - "experimentalModel": experimental_mode, - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty, - "presencePenalty": presence_penalty, - "countPenalty": count_penalty, - "epoch": epoch, - } + body = self._create_body( + model=model, + prompt=prompt, + max_tokens=max_tokens, + num_results=num_results, + min_tokens=min_tokens, + temperature=temperature, + top_p=top_p, + top_k_return=top_k_return, + custom_model=custom_model, + experimental_mode=experimental_mode, + stop_sequences=stop_sequences, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + count_penalty=count_penalty, + epoch=epoch, + **kwargs, + ) return self._json_to_response(self._post(url=url, body=body)) diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index acf0665e..7a50279e 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -17,13 +17,14 @@ def create( **kwargs, ) -> None: url = f"{self._client.get_base_url()}/{self._module_name}" - body = { - "dataset_id": dataset_id, - "model_name": model_name, - "model_type": model_type, - "learning_rate": learning_rate, - "num_epochs": num_epochs, - } + body = self._create_body( + dataset_id=dataset_id, + model_name=model_name, + model_type=model_type, + learning_rate=learning_rate, + num_epochs=num_epochs, + **kwargs, + ) self._post(url=url, body=body) def list(self) -> List[CustomModelResponse]: diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 69e0c6bf..3a2ce2e5 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -18,14 +18,14 @@ def upload( **kwargs, ): files = {"dataset_file": open(file_path, "rb")} - body = { - "dataset_name": dataset_name, - "selected_columns": selected_columns, - "approve_whitespace_correction": approve_whitespace_correction, - "delete_long_rows": delete_long_rows, - "split_ratio": split_ratio, - } - + body = self._create_body( + dataset_name=dataset_name, + selected_columns=selected_columns, + approve_whitespace_correction=approve_whitespace_correction, + delete_long_rows=delete_long_rows, + split_ratio=split_ratio, + **kwargs, + ) return self._post( url=self._base_url(), body=body, diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index 4cca4818..5f98ada6 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -8,6 +8,7 @@ class StudioEmbed(StudioResource, Embed): def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse: url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body={"texts": texts, "type": type}) + body = self._create_body(texts=texts, type=type, **kwargs) + response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index 72877532..2b0617a1 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -5,9 +5,7 @@ class StudioGEC(StudioResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: - body = { - "text": text, - } + body = self._create_body(text=text, **kwargs) url = f"{self._client.get_base_url()}/{self._module_name}" response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index e111a5fc..7f48ea47 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -12,7 +12,7 @@ def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: raise EmptyMandatoryListException("types") url = f"{self._client.get_base_url()}/{self._module_name}" - - response = self._post(url=url, body={"text": text, "types": types}) + body = self._create_body(text=text, types=types, **kwargs) + response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index bef0de0c..6e287fac 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -15,12 +15,13 @@ def create( end_index: Optional[int] = None, **kwargs, ) -> ParaphraseResponse: - body = { - "text": text, - "style": style, - "startIndex": start_index, - "endIndex": end_index, - } + body = self._create_body( + text=text, + style=style, + start_index=start_index, + end_index=end_index, + **kwargs, + ) url = f"{self._client.get_base_url()}/{self._module_name}" response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index 8c630048..f6e5ac34 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -4,11 +4,7 @@ class StudioSegmentation(StudioResource, Segmentation): def create(self, source: str, source_type: str): - body = { - "source": source, - "sourceType": source_type, - } - + body = self._create_body(source=source, source_type=source_type) url = f"{self._client.get_base_url()}/{self._module_name}" raw_response = self._post(url=url, body=body) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 753f0c39..f513b7b2 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -19,13 +19,14 @@ def create( **kwargs, ) -> SummarizeResponse: # Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object. - params = { - "source": source, - "sourceType": source_type, - "focus": focus, - "summaryMethod": summary_method, - } + body = self._create_body( + source=source, + source_type=source_type, + focus=focus, + summary_method=summary_method, + **kwargs, + ) url = f"{self._client.get_base_url()}/{self._module_name}" - response = self._post(url=url, body=params) + response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index a79f4732..8ecfa928 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -11,11 +11,12 @@ class StudioSummarizeBySegment(StudioResource, SummarizeBySegment): def create( self, source: str, source_type: str, *, focus: Optional[str] = None, **kwargs ) -> SummarizeBySegmentResponse: - body = { - "source": source, - "sourceType": source_type, - "focus": focus, - } + body = self._create_body( + source=source, + source_type=source_type, + focus=focus, + **kwargs, + ) url = f"{self._client.get_base_url()}/{self._module_name}" response = self._post(url=url, body=body) return self._json_to_response(response) diff --git a/ai21/resources/bases/answer_base.py b/ai21/resources/bases/answer_base.py index 2b193059..b6b007f0 100644 --- a/ai21/resources/bases/answer_base.py +++ b/ai21/resources/bases/answer_base.py @@ -20,3 +20,13 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse: return AnswerResponse.model_validate(json) + + def _create_body( + self, + context: str, + question: str, + answer_length: Optional[str], + mode: Optional[str], + **kwargs, + ) -> Dict[str, Any]: + return {"context": context, "question": question, "answerLength": answer_length, "mode": mode, **kwargs} diff --git a/ai21/resources/bases/chat_base.py b/ai21/resources/bases/chat_base.py index ff5de434..a4aec1de 100644 --- a/ai21/resources/bases/chat_base.py +++ b/ai21/resources/bases/chat_base.py @@ -36,3 +36,37 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse: return ChatResponse.model_validate(json) + + def _create_body( + self, + model: str, + messages: List[Message], + system: str, + num_results: Optional[int] = 1, + temperature: Optional[float] = 0.7, + max_tokens: Optional[int] = 300, + min_tokens: Optional[int] = 0, + top_p: Optional[float] = 1.0, + top_k_returns: Optional[int] = 0, + stop_sequences: Optional[List[str]] = None, + frequency_penalty: Optional[Dict[str, Any]] = None, + presence_penalty: Optional[Dict[str, Any]] = None, + count_penalty: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Dict[str, Any]: + return { + "model": model, + "system": system, + "messages": messages, + "temperature": temperature, + "maxTokens": max_tokens, + "minTokens": min_tokens, + "numResults": num_results, + "topP": top_p, + "topKReturn": top_k_returns, + "stopSequences": stop_sequences, + "frequencyPenalty": frequency_penalty, + "presencePenalty": presence_penalty, + "countPenalty": count_penalty, + **kwargs, + } diff --git a/ai21/resources/bases/completion_base.py b/ai21/resources/bases/completion_base.py index 73818c71..ec9eb78a 100644 --- a/ai21/resources/bases/completion_base.py +++ b/ai21/resources/bases/completion_base.py @@ -33,22 +33,40 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse: return CompletionsResponse.model_validate(json) - -class AWSCompletionAdapter: - @abstractmethod - def create( + def _create_body( self, - model_id: str, + model: str, prompt: str, - max_tokens: int = 64, - num_results: int = 1, - min_tokens=0, - temperature=0.7, - top_p=1, - top_k_return=0, - stop_sequences: Optional[List[str]] = (), - frequency_penalty: Optional[Dict[str, Any]] = {}, - presence_penalty: Optional[Dict[str, Any]] = {}, - count_penalty: Optional[Dict[str, Any]] = {}, - ) -> CompletionsResponse: - pass + max_tokens: Optional[int], + num_results: Optional[int], + min_tokens: Optional[int], + temperature: Optional[float], + top_p: Optional[int], + top_k_return: Optional[int], + custom_model: Optional[str], + experimental_mode: bool, + stop_sequences: Optional[List[str]], + frequency_penalty: Optional[Dict[str, Any]], + presence_penalty: Optional[Dict[str, Any]], + count_penalty: Optional[Dict[str, Any]], + epoch: Optional[int], + **kwargs, + ): + return { + "model": model, + "customModel": custom_model, + "experimentalModel": experimental_mode, + "prompt": prompt, + "maxTokens": max_tokens, + "numResults": num_results, + "minTokens": min_tokens, + "temperature": temperature, + "topP": top_p, + "topKReturn": top_k_return, + "stopSequences": stop_sequences or [], + "frequencyPenalty": frequency_penalty, + "presencePenalty": presence_penalty, + "countPenalty": count_penalty, + "epoch": epoch, + **kwargs, + } diff --git a/ai21/resources/bases/custom_model_base.py b/ai21/resources/bases/custom_model_base.py index f3a0f122..bef5d08e 100644 --- a/ai21/resources/bases/custom_model_base.py +++ b/ai21/resources/bases/custom_model_base.py @@ -30,3 +30,21 @@ def get(self, resource_id: str) -> CustomModelResponse: def _json_to_response(self, json: Dict[str, Any]) -> CustomModelResponse: return CustomModelResponse.model_validate(json) + + def _create_body( + self, + dataset_id: str, + model_name: str, + model_type: str, + learning_rate: Optional[float], + num_epochs: Optional[int], + **kwargs, + ) -> Dict[str, Any]: + return { + "dataset_id": dataset_id, + "model_name": model_name, + "model_type": model_type, + "learning_rate": learning_rate, + "num_epochs": num_epochs, + **kwargs, + } diff --git a/ai21/resources/bases/dataset_base.py b/ai21/resources/bases/dataset_base.py index 8a08fd97..ed492fd7 100644 --- a/ai21/resources/bases/dataset_base.py +++ b/ai21/resources/bases/dataset_base.py @@ -31,3 +31,21 @@ def get(self, dataset_pid: str): def _json_to_response(self, json: Dict[str, Any]) -> DatasetResponse: return DatasetResponse.model_validate(json) + + def _create_body( + self, + dataset_name: str, + selected_columns: Optional[str], + approve_whitespace_correction: Optional[bool], + delete_long_rows: Optional[bool], + split_ratio: Optional[float], + **kwargs, + ) -> Dict[str, Any]: + return { + "dataset_name": dataset_name, + "selected_columns": selected_columns, + "approve_whitespace_correction": approve_whitespace_correction, + "delete_long_rows": delete_long_rows, + "split_ratio": split_ratio, + **kwargs, + } diff --git a/ai21/resources/bases/embed_base.py b/ai21/resources/bases/embed_base.py index 18c0db76..63ca2349 100644 --- a/ai21/resources/bases/embed_base.py +++ b/ai21/resources/bases/embed_base.py @@ -13,3 +13,6 @@ def create(self, texts: List[str], *, type: Optional[str] = None, **kwargs) -> E def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse: return EmbedResponse.model_validate(json) + + def _create_body(self, texts: List[str], type: Optional[str], **kwargs) -> Dict[str, Any]: + return {"texts": texts, "type": type, **kwargs} diff --git a/ai21/resources/bases/gec_base.py b/ai21/resources/bases/gec_base.py index 011e3924..f3e9094f 100644 --- a/ai21/resources/bases/gec_base.py +++ b/ai21/resources/bases/gec_base.py @@ -13,3 +13,6 @@ def create(self, text: str, **kwargs) -> GECResponse: def _json_to_response(self, json: Dict[str, Any]) -> GECResponse: return GECResponse.model_validate(json) + + def _create_body(self, text: str, **kwargs) -> Dict[str, Any]: + return {"text": text, **kwargs} diff --git a/ai21/resources/bases/improvements_base.py b/ai21/resources/bases/improvements_base.py index d5d917bb..aa64472f 100644 --- a/ai21/resources/bases/improvements_base.py +++ b/ai21/resources/bases/improvements_base.py @@ -13,3 +13,6 @@ def create(self, text: str, types: List[str], **kwargs) -> ImprovementsResponse: def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse: return ImprovementsResponse.model_validate(json) + + def _create_body(self, text: str, types: List[str], **kwargs) -> Dict[str, Any]: + return {"text": text, "types": types, **kwargs} diff --git a/ai21/resources/bases/paraphrase_base.py b/ai21/resources/bases/paraphrase_base.py index e3327ebc..e5836e92 100644 --- a/ai21/resources/bases/paraphrase_base.py +++ b/ai21/resources/bases/paraphrase_base.py @@ -21,3 +21,19 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse: return ParaphraseResponse.model_validate(json) + + def _create_body( + self, + text: str, + style: Optional[str], + start_index: Optional[int], + end_index: Optional[int], + **kwargs, + ) -> Dict[str, Any]: + return { + "text": text, + "style": style, + "startIndex": start_index, + "endIndex": end_index, + **kwargs, + } diff --git a/ai21/resources/bases/segmentation_base.py b/ai21/resources/bases/segmentation_base.py index 5cbd8217..6861bccf 100644 --- a/ai21/resources/bases/segmentation_base.py +++ b/ai21/resources/bases/segmentation_base.py @@ -13,3 +13,6 @@ def create(self, source: str, source_type: str): def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse: return SegmentationResponse.model_validate(json) + + def _create_body(self, source: str, source_type: str, **kwargs) -> Dict[str, Any]: + return {"source": source, "sourceType": source_type, **kwargs} diff --git a/ai21/resources/bases/summarize_base.py b/ai21/resources/bases/summarize_base.py index 1479f681..aaac6aff 100644 --- a/ai21/resources/bases/summarize_base.py +++ b/ai21/resources/bases/summarize_base.py @@ -19,3 +19,19 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: return SummarizeResponse.model_validate(json) + + def _create_body( + self, + source: str, + source_type: str, + focus: Optional[str], + summary_method: Optional[str], + **kwargs, + ) -> Dict[str, Any]: + return { + "source": source, + "sourceType": source_type, + "focus": focus, + "summaryMethod": summary_method, + **kwargs, + } diff --git a/ai21/resources/bases/summarize_by_segment_base.py b/ai21/resources/bases/summarize_by_segment_base.py index 6e78f22c..9ac9833e 100644 --- a/ai21/resources/bases/summarize_by_segment_base.py +++ b/ai21/resources/bases/summarize_by_segment_base.py @@ -22,3 +22,17 @@ def create( def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse: return SummarizeBySegmentResponse.model_validate(json) + + def _create_body( + self, + source: str, + source_type: str, + focus: Optional[str], + **kwargs, + ) -> Dict[str, Any]: + return { + "source": source, + "sourceType": source_type, + "focus": focus, + **kwargs, + }