From b39c92cfbdc88c9f98f3d5bfb20aa850f6211b78 Mon Sep 17 00:00:00 2001 From: amirai21 <89905406+amirai21@users.noreply.github.com> Date: Mon, 11 Nov 2024 14:30:09 +0200 Subject: [PATCH] feat: Sunset tsms and j2 (#225) feat!: Removed Legacy TSM Models --- .gitignore | 2 + README.md | 29 +- ai21/__init__.py | 24 +- ai21/clients/common/answer_base.py | 36 --- ai21/clients/common/completion_base.py | 15 +- ai21/clients/common/custom_model_base.py | 57 ---- ai21/clients/common/dataset_base.py | 60 ---- ai21/clients/common/embed_base.py | 23 -- ai21/clients/common/gec_base.py | 25 -- ai21/clients/common/improvements_base.py | 27 -- ai21/clients/common/paraphrase_base.py | 50 ---- ai21/clients/common/segmentation_base.py | 22 -- ai21/clients/common/summarize_base.py | 48 ---- .../common/summarize_by_segment_base.py | 41 --- ai21/clients/sagemaker/__init__.py | 0 .../sagemaker/ai21_sagemaker_client.py | 96 ------- ai21/clients/sagemaker/constants.py | 9 - ai21/clients/sagemaker/resources/__init__.py | 0 .../sagemaker/resources/sagemaker_answer.py | 32 --- .../resources/sagemaker_completion.py | 125 --------- .../sagemaker/resources/sagemaker_gec.py | 24 -- .../resources/sagemaker_paraphrase.py | 50 ---- .../sagemaker/resources/sagemaker_resource.py | 88 ------ .../resources/sagemaker_summarize.py | 54 ---- ai21/clients/sagemaker/sagemaker_session.py | 59 ---- ai21/clients/studio/ai21_client.py | 34 --- ai21/clients/studio/async_ai21_client.py | 22 -- .../clients/studio/resources/studio_answer.py | 30 -- .../studio/resources/studio_completion.py | 8 +- .../studio/resources/studio_custom_model.py | 68 ----- .../studio/resources/studio_dataset.py | 81 ------ ai21/clients/studio/resources/studio_embed.py | 22 -- ai21/clients/studio/resources/studio_gec.py | 20 -- .../studio/resources/studio_improvements.py | 26 -- .../studio/resources/studio_library.py | 106 +------ .../studio/resources/studio_paraphrase.py | 50 ---- .../studio/resources/studio_segmentation.py | 20 -- .../studio/resources/studio_summarize.py | 50 ---- .../resources/studio_summarize_by_segment.py | 46 --- ai21/models/__init__.py | 42 --- ai21/models/embed_type.py | 6 - ai21/models/improvement_type.py | 9 - ai21/models/paraphrase_style_type.py | 9 - ai21/models/responses/answer_response.py | 11 - .../models/responses/custom_model_response.py | 29 -- ai21/models/responses/dataset_response.py | 16 -- ai21/models/responses/embed_response.py | 15 - ai21/models/responses/gec_response.py | 28 -- ai21/models/responses/improvement_response.py | 18 -- .../responses/library_answer_response.py | 19 -- .../responses/library_search_response.py | 20 -- ai21/models/responses/paraphrase_response.py | 12 - .../models/responses/segmentation_response.py | 15 - .../summarize_by_segment_response.py | 25 -- ai21/models/responses/summarize_response.py | 6 - ai21/models/summary_method.py | 7 - ai21/services/__init__.py | 0 ai21/services/sagemaker.py | 69 ----- examples/sagemaker/answer.py | 14 - examples/sagemaker/async_answer.py | 21 -- examples/sagemaker/async_completion.py | 47 ---- examples/sagemaker/async_gec.py | 14 - examples/sagemaker/async_paraphrase.py | 17 -- examples/sagemaker/async_summarization.py | 23 -- examples/sagemaker/completion.py | 39 --- examples/sagemaker/gec.py | 6 - examples/sagemaker/get_model_package_arn.py | 4 - examples/sagemaker/paraphrase.py | 11 - examples/sagemaker/summarization.py | 15 - examples/studio/answer.py | 13 - examples/studio/async_answer.py | 21 -- examples/studio/async_completion.py | 84 ------ examples/studio/async_custom_model.py | 26 -- .../studio/async_custom_model_completion.py | 50 ---- examples/studio/async_dataset.py | 19 -- examples/studio/async_embed.py | 17 -- examples/studio/async_gec.py | 19 -- examples/studio/async_improvements.py | 24 -- examples/studio/async_library_answer.py | 13 - examples/studio/async_library_search.py | 13 - examples/studio/async_paraphrase.py | 22 -- examples/studio/async_segmentation.py | 23 -- examples/studio/async_summarize.py | 23 -- examples/studio/async_summarize_by_segment.py | 23 -- examples/studio/completion.py | 76 ----- examples/studio/custom_model.py | 19 -- examples/studio/custom_model_completion.py | 42 --- examples/studio/dataset.py | 11 - examples/studio/embed.py | 9 - examples/studio/gec.py | 11 - examples/studio/improvements.py | 16 -- examples/studio/library_answer.py | 5 - examples/studio/library_search.py | 5 - examples/studio/paraphrase.py | 14 - examples/studio/segmentation.py | 16 -- examples/studio/summarize.py | 15 - examples/studio/summarize_by_segment.py | 15 - .../clients/studio/test_answer.py | 61 ---- .../clients/studio/test_completion.py | 265 ------------------ .../clients/studio/test_embed.py | 63 ----- .../clients/studio/test_gec.py | 60 ---- .../clients/studio/test_improvements.py | 27 -- .../clients/studio/test_library_answer.py | 38 --- .../clients/studio/test_library_search.py | 25 -- .../clients/studio/test_paraphrase.py | 77 ----- .../clients/studio/test_segmentation.py | 100 ------- .../clients/studio/test_summarize.py | 113 -------- .../studio/test_summarize_by_segment.py | 122 -------- .../clients/test_sagemaker.py | 43 --- .../integration_tests/clients/test_studio.py | 58 ---- tests/integration_tests/services/__init__.py | 0 .../services/test_sagemaker.py | 36 --- .../clients/studio/resources/conftest.py | 247 ---------------- .../resources/test_async_studio_resource.py | 61 ---- .../studio/resources/test_completion.py | 12 - .../studio/resources/test_studio_resources.py | 61 ---- tests/unittests/models/response_mocks.py | 181 ------------ tests/unittests/models/test_serialization.py | 45 --- tests/unittests/services/__init__.py | 0 tests/unittests/services/sagemaker_stub.py | 12 - tests/unittests/services/test_sagemaker.py | 46 --- tests/unittests/test_imports.py | 19 +- 122 files changed, 17 insertions(+), 4545 deletions(-) delete mode 100644 ai21/clients/common/answer_base.py delete mode 100644 ai21/clients/common/custom_model_base.py delete mode 100644 ai21/clients/common/dataset_base.py delete mode 100644 ai21/clients/common/embed_base.py delete mode 100644 ai21/clients/common/gec_base.py delete mode 100644 ai21/clients/common/improvements_base.py delete mode 100644 ai21/clients/common/paraphrase_base.py delete mode 100644 ai21/clients/common/segmentation_base.py delete mode 100644 ai21/clients/common/summarize_base.py delete mode 100644 ai21/clients/common/summarize_by_segment_base.py delete mode 100644 ai21/clients/sagemaker/__init__.py delete mode 100644 ai21/clients/sagemaker/ai21_sagemaker_client.py delete mode 100644 ai21/clients/sagemaker/constants.py delete mode 100644 ai21/clients/sagemaker/resources/__init__.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_answer.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_completion.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_gec.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_paraphrase.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_resource.py delete mode 100644 ai21/clients/sagemaker/resources/sagemaker_summarize.py delete mode 100644 ai21/clients/sagemaker/sagemaker_session.py delete mode 100644 ai21/clients/studio/resources/studio_answer.py delete mode 100644 ai21/clients/studio/resources/studio_custom_model.py delete mode 100644 ai21/clients/studio/resources/studio_dataset.py delete mode 100644 ai21/clients/studio/resources/studio_embed.py delete mode 100644 ai21/clients/studio/resources/studio_gec.py delete mode 100644 ai21/clients/studio/resources/studio_improvements.py delete mode 100644 ai21/clients/studio/resources/studio_paraphrase.py delete mode 100644 ai21/clients/studio/resources/studio_segmentation.py delete mode 100644 ai21/clients/studio/resources/studio_summarize.py delete mode 100644 ai21/clients/studio/resources/studio_summarize_by_segment.py delete mode 100644 ai21/models/embed_type.py delete mode 100644 ai21/models/improvement_type.py delete mode 100644 ai21/models/paraphrase_style_type.py delete mode 100644 ai21/models/responses/answer_response.py delete mode 100644 ai21/models/responses/custom_model_response.py delete mode 100644 ai21/models/responses/dataset_response.py delete mode 100644 ai21/models/responses/embed_response.py delete mode 100644 ai21/models/responses/gec_response.py delete mode 100644 ai21/models/responses/improvement_response.py delete mode 100644 ai21/models/responses/library_answer_response.py delete mode 100644 ai21/models/responses/library_search_response.py delete mode 100644 ai21/models/responses/paraphrase_response.py delete mode 100644 ai21/models/responses/segmentation_response.py delete mode 100644 ai21/models/responses/summarize_by_segment_response.py delete mode 100644 ai21/models/responses/summarize_response.py delete mode 100644 ai21/models/summary_method.py delete mode 100644 ai21/services/__init__.py delete mode 100644 ai21/services/sagemaker.py delete mode 100644 examples/sagemaker/answer.py delete mode 100644 examples/sagemaker/async_answer.py delete mode 100644 examples/sagemaker/async_completion.py delete mode 100644 examples/sagemaker/async_gec.py delete mode 100644 examples/sagemaker/async_paraphrase.py delete mode 100644 examples/sagemaker/async_summarization.py delete mode 100644 examples/sagemaker/completion.py delete mode 100644 examples/sagemaker/gec.py delete mode 100644 examples/sagemaker/get_model_package_arn.py delete mode 100644 examples/sagemaker/paraphrase.py delete mode 100644 examples/sagemaker/summarization.py delete mode 100644 examples/studio/answer.py delete mode 100644 examples/studio/async_answer.py delete mode 100644 examples/studio/async_completion.py delete mode 100644 examples/studio/async_custom_model.py delete mode 100644 examples/studio/async_custom_model_completion.py delete mode 100644 examples/studio/async_dataset.py delete mode 100644 examples/studio/async_embed.py delete mode 100644 examples/studio/async_gec.py delete mode 100644 examples/studio/async_improvements.py delete mode 100644 examples/studio/async_library_answer.py delete mode 100644 examples/studio/async_library_search.py delete mode 100644 examples/studio/async_paraphrase.py delete mode 100644 examples/studio/async_segmentation.py delete mode 100644 examples/studio/async_summarize.py delete mode 100644 examples/studio/async_summarize_by_segment.py delete mode 100644 examples/studio/completion.py delete mode 100644 examples/studio/custom_model.py delete mode 100644 examples/studio/custom_model_completion.py delete mode 100644 examples/studio/dataset.py delete mode 100644 examples/studio/embed.py delete mode 100644 examples/studio/gec.py delete mode 100644 examples/studio/improvements.py delete mode 100644 examples/studio/library_answer.py delete mode 100644 examples/studio/library_search.py delete mode 100644 examples/studio/paraphrase.py delete mode 100644 examples/studio/segmentation.py delete mode 100644 examples/studio/summarize.py delete mode 100644 examples/studio/summarize_by_segment.py delete mode 100644 tests/integration_tests/clients/studio/test_answer.py delete mode 100644 tests/integration_tests/clients/studio/test_completion.py delete mode 100644 tests/integration_tests/clients/studio/test_embed.py delete mode 100644 tests/integration_tests/clients/studio/test_gec.py delete mode 100644 tests/integration_tests/clients/studio/test_improvements.py delete mode 100644 tests/integration_tests/clients/studio/test_library_answer.py delete mode 100644 tests/integration_tests/clients/studio/test_library_search.py delete mode 100644 tests/integration_tests/clients/studio/test_paraphrase.py delete mode 100644 tests/integration_tests/clients/studio/test_segmentation.py delete mode 100644 tests/integration_tests/clients/studio/test_summarize.py delete mode 100644 tests/integration_tests/clients/studio/test_summarize_by_segment.py delete mode 100644 tests/integration_tests/clients/test_sagemaker.py delete mode 100644 tests/integration_tests/services/__init__.py delete mode 100644 tests/integration_tests/services/test_sagemaker.py delete mode 100644 tests/unittests/clients/studio/resources/test_completion.py delete mode 100644 tests/unittests/services/__init__.py delete mode 100644 tests/unittests/services/sagemaker_stub.py delete mode 100644 tests/unittests/services/test_sagemaker.py diff --git a/.gitignore b/.gitignore index ffd1b91f..86f7313e 100644 --- a/.gitignore +++ b/.gitignore @@ -119,3 +119,5 @@ fabric.properties *.orig *.rej __pycache__ + +tests/integration_tests/test-file/* diff --git a/README.md b/README.md index be1d8d8a..b3bbeb80 100644 --- a/README.md +++ b/README.md @@ -304,7 +304,7 @@ Note that jamba-instruct supports async and streaming as well. -For a more detailed example, see the completion [examples](examples/studio/completion.py). +For a more detailed example, see the completion [examples](examples/studio/chat/chat_completions.py). --- @@ -388,33 +388,8 @@ For a more detailed example, see the chat [sync](examples/studio/conversational_ --- -## More Models - -## TSMs - -AI21 Studio's Task-Specific Models offer a range of powerful tools. These models have been specifically designed for their respective tasks and provide high-quality results while optimizing efficiency. -The full documentation and guides can be found [here](https://docs.ai21.com/docs/task-specific). - -### Contextual Answers - -The `answer` API allows you to access our high-quality question answering model. - -```python -from ai21 import AI21Client - -client = AI21Client() -response = client.answer.create( - context="This is a text is for testing purposes", - question="Question about context", -) -``` - -A detailed explanation on Contextual Answers, can be found [here](https://docs.ai21.com/docs/contextual-answers-api) - ### File Upload ---- - ```python from ai21 import AI21Client @@ -430,8 +405,6 @@ file_id = client.library.files.create( uploaded_file = client.library.files.get(file_id) ``` -For more information on more Task Specific Models, see the [documentation](https://docs.ai21.com/reference/paraphrase-api-ref). - ## Token Counting --- diff --git a/ai21/__init__.py b/ai21/__init__.py index 8daa8569..0548e4ce 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -14,7 +14,6 @@ TooManyRequestsError, ) from ai21.logger import setup_logger -from ai21.services.sagemaker import SageMaker from ai21.version import VERSION __version__ = VERSION @@ -27,12 +26,6 @@ def _import_bedrock_client(): return AI21BedrockClient -def _import_sagemaker_client(): - from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient - - return AI21SageMakerClient - - def _import_bedrock_model_id(): from ai21.clients.bedrock.bedrock_model_id import BedrockModelID @@ -45,12 +38,6 @@ def _import_async_bedrock_client(): return AsyncAI21BedrockClient -def _import_async_sagemaker_client(): - from ai21.clients.sagemaker.ai21_sagemaker_client import AsyncAI21SageMakerClient - - return AsyncAI21SageMakerClient - - def _import_vertex_client(): from ai21.clients.vertex.ai21_vertex_client import AI21VertexClient @@ -68,25 +55,19 @@ def __getattr__(name: str) -> Any: if name == "AI21BedrockClient": return _import_bedrock_client() - if name == "AI21SageMakerClient": - return _import_sagemaker_client() - if name == "BedrockModelID": return _import_bedrock_model_id() if name == "AsyncAI21BedrockClient": return _import_async_bedrock_client() - if name == "AsyncAI21SageMakerClient": - return _import_async_sagemaker_client() - if name == "AI21VertexClient": return _import_vertex_client() if name == "AsyncAI21VertexClient": return _import_async_vertex_client() except ImportError as e: - raise ImportError('Please install "ai21[AWS]" for SageMaker or Bedrock, or "ai21[Vertex]" for Vertex') from e + raise ImportError('Please install "ai21[AWS]" for Bedrock, or "ai21[Vertex]" for Vertex') from e __all__ = [ @@ -100,13 +81,10 @@ def __getattr__(name: str) -> Any: "ModelPackageDoesntExistError", "TooManyRequestsError", "AI21BedrockClient", - "AI21SageMakerClient", "BedrockModelID", - "SageMaker", "AI21AzureClient", "AsyncAI21AzureClient", "AsyncAI21BedrockClient", - "AsyncAI21SageMakerClient", "AI21VertexClient", "AsyncAI21VertexClient", ] diff --git a/ai21/clients/common/answer_base.py b/ai21/clients/common/answer_base.py deleted file mode 100644 index b74c87c7..00000000 --- a/ai21/clients/common/answer_base.py +++ /dev/null @@ -1,36 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, cast - -from ai21.models import AnswerResponse -from ai21.models._pydantic_compatibility import _from_dict - - -class Answer(ABC): - _module_name = "answer" - - @abstractmethod - def create( - self, - context: str, - question: str, - **kwargs, - ) -> AnswerResponse: - """ - - :param context: A string containing the document context for which the question will be answered - :param question: A string containing the question to be answered based on the provided context. - :param kwargs: - :return: - """ - pass - - def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse: - return cast(_from_dict(obj=AnswerResponse, obj_dict=json), AnswerResponse) - - def _create_body( - self, - context: str, - question: str, - **kwargs, - ) -> Dict[str, Any]: - return {"context": context, "question": question, **kwargs} diff --git a/ai21/clients/common/completion_base.py b/ai21/clients/common/completion_base.py index 3796b81b..a57c9e1f 100644 --- a/ai21/clients/common/completion_base.py +++ b/ai21/clients/common/completion_base.py @@ -42,7 +42,6 @@ def create( temperature: float | NOT_GIVEN = NOT_GIVEN, top_p: float | NotGiven = NOT_GIVEN, top_k_return: int | NotGiven = NOT_GIVEN, - custom_model: str | NotGiven = NOT_GIVEN, stop_sequences: List[str] | NotGiven = NOT_GIVEN, frequency_penalty: Penalty | NotGiven = NOT_GIVEN, presence_penalty: Penalty | NotGiven = NOT_GIVEN, @@ -60,7 +59,6 @@ def create( :param temperature: A value controlling the "creativity" of the model's responses. :param top_p: A value controlling the diversity of the model's responses. :param top_k_return: The number of top-scoring tokens to consider for each generation step. - :param custom_model: :param stop_sequences: Stops decoding if any of the strings is generated :param frequency_penalty: A penalty applied to tokens that are frequently generated. :param presence_penalty: A penalty applied to tokens that are already present in the prompt. @@ -84,7 +82,6 @@ def _create_body( temperature: float | NotGiven, top_p: float | NotGiven, top_k_return: int | NotGiven, - custom_model: str | NotGiven, stop_sequences: List[str] | NotGiven, frequency_penalty: Penalty | NotGiven, presence_penalty: Penalty | NotGiven, @@ -96,7 +93,6 @@ def _create_body( return remove_not_given( { "model": model, - "customModel": custom_model, "prompt": prompt, "maxTokens": max_tokens, "numResults": num_results, @@ -114,12 +110,5 @@ def _create_body( } ) - def _get_completion_path(self, model: str, custom_model: str | NotGiven = NOT_GIVEN): - path = f"/{model}" - - if custom_model: - path = f"{path}/{custom_model}" - - path = f"{path}/{self._module_name}" - - return path + def _get_completion_path(self, model: str): + return f"/{model}/{self._module_name}" diff --git a/ai21/clients/common/custom_model_base.py b/ai21/clients/common/custom_model_base.py deleted file mode 100644 index 776626d4..00000000 --- a/ai21/clients/common/custom_model_base.py +++ /dev/null @@ -1,57 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, List, Any, Dict - -from ai21.models import CustomBaseModelResponse - - -class CustomModel(ABC): - _module_name = "custom-model" - - @abstractmethod - def create( - self, - dataset_id: str, - model_name: str, - model_type: str, - *, - learning_rate: Optional[float] = None, - num_epochs: Optional[int] = None, - **kwargs, - ) -> None: - """ - - :param dataset_id: The dataset you want to train your model on. - :param model_name: The name of your trained model - :param model_type: The type of model to train. - :param learning_rate: The learning rate used for training. - :param num_epochs: Number of epochs for training - :param kwargs: - :return: - """ - pass - - @abstractmethod - def list(self) -> List[CustomBaseModelResponse]: - pass - - @abstractmethod - def get(self, resource_id: str) -> CustomBaseModelResponse: - pass - - def _create_body( - self, - dataset_id: str, - model_name: str, - model_type: str, - learning_rate: Optional[float], - num_epochs: Optional[int], - **kwargs, - ) -> Dict[str, Any]: - return { - "dataset_id": dataset_id, - "model_name": model_name, - "model_type": model_type, - "learning_rate": learning_rate, - "num_epochs": num_epochs, - **kwargs, - } diff --git a/ai21/clients/common/dataset_base.py b/ai21/clients/common/dataset_base.py deleted file mode 100644 index 617f75cd..00000000 --- a/ai21/clients/common/dataset_base.py +++ /dev/null @@ -1,60 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Any, Dict - - -class Dataset(ABC): - _module_name = "dataset" - - @abstractmethod - def create( - self, - file_path: str, - dataset_name: str, - *, - selected_columns: Optional[str] = None, - approve_whitespace_correction: Optional[bool] = None, - delete_long_rows: Optional[bool] = None, - split_ratio: Optional[float] = None, - **kwargs, - ): - """ - - :param file_path: Local path to dataset - :param dataset_name: Dataset name. Must be unique - :param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns. - :param approve_whitespace_correction: Automatically correct examples that violate best practices - :param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens - :param split_ratio: - :param kwargs: - :return: - """ - pass - - @abstractmethod - def list(self): - pass - - @abstractmethod - def get(self, dataset_pid: str): - pass - - def _create_body( - self, - dataset_name: str, - selected_columns: Optional[str], - approve_whitespace_correction: Optional[bool], - delete_long_rows: Optional[bool], - split_ratio: Optional[float], - **kwargs, - ) -> Dict[str, Any]: - return { - "dataset_name": dataset_name, - "selected_columns": selected_columns, - "approve_whitespace_correction": approve_whitespace_correction, - "delete_long_rows": delete_long_rows, - "split_ratio": split_ratio, - **kwargs, - } - - def _base_url(self, base_url: str) -> str: - return f"{base_url}/{self._module_name}" diff --git a/ai21/clients/common/embed_base.py b/ai21/clients/common/embed_base.py deleted file mode 100644 index 6e16a795..00000000 --- a/ai21/clients/common/embed_base.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Any, Dict, Optional - -from ai21.models import EmbedType, EmbedResponse - - -class Embed(ABC): - _module_name = "embed" - - @abstractmethod - def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: - """ - - :param texts: A list of strings, each representing a document or segment of text to be embedded. - :param type: For retrieval/search use cases, indicates whether the texts that were - sent are segments or the query. - :param kwargs: - :return: - """ - pass - - def _create_body(self, texts: List[str], type: Optional[str], **kwargs) -> Dict[str, Any]: - return {"texts": texts, "type": type, **kwargs} diff --git a/ai21/clients/common/gec_base.py b/ai21/clients/common/gec_base.py deleted file mode 100644 index 9855acb0..00000000 --- a/ai21/clients/common/gec_base.py +++ /dev/null @@ -1,25 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, Any, cast - -from ai21.models import GECResponse -from ai21.models._pydantic_compatibility import _from_dict - - -class GEC(ABC): - _module_name = "gec" - - @abstractmethod - def create(self, text: str, **kwargs) -> GECResponse: - """ - - :param text: The input text to be corrected. - :param kwargs: - :return: - """ - pass - - def _json_to_response(self, json: Dict[str, Any]) -> GECResponse: - return cast(_from_dict(obj=GECResponse, obj_dict=json), GECResponse) - - def _create_body(self, text: str, **kwargs) -> Dict[str, Any]: - return {"text": text, **kwargs} diff --git a/ai21/clients/common/improvements_base.py b/ai21/clients/common/improvements_base.py deleted file mode 100644 index 47e21bbe..00000000 --- a/ai21/clients/common/improvements_base.py +++ /dev/null @@ -1,27 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List - -from ai21.errors import EmptyMandatoryListError -from ai21.models import ImprovementType, ImprovementsResponse - - -class Improvements(ABC): - _module_name = "improvements" - - @abstractmethod - def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: - """ - - :param text: The input text to be improved. - :param types: Types of improvements to apply. - :param kwargs: - :return: - """ - pass - - def _create_body(self, text: str, types: List[str], **kwargs) -> Dict[str, Any]: - return {"text": text, "types": types, **kwargs} - - def _validate_types(self, types: List[ImprovementType]): - if len(types) == 0: - raise EmptyMandatoryListError("types") diff --git a/ai21/clients/common/paraphrase_base.py b/ai21/clients/common/paraphrase_base.py deleted file mode 100644 index faf3a7a8..00000000 --- a/ai21/clients/common/paraphrase_base.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Any, Dict, cast - -from ai21.models import ParaphraseStyleType, ParaphraseResponse -from ai21.models._pydantic_compatibility import _from_dict - - -class Paraphrase(ABC): - _module_name = "paraphrase" - - @abstractmethod - def create( - self, - text: str, - *, - style: Optional[ParaphraseStyleType] = None, - start_index: Optional[int] = 0, - end_index: Optional[int] = None, - **kwargs, - ) -> ParaphraseResponse: - """ - - :param text: The input text to be paraphrased. - :param style: Controls length and tone - :param start_index: Specifies the starting position of the paraphrasing process in the given text - :param end_index: specifies the position of the last character to be paraphrased, including the character - following it. If the parameter is not provided, the default value is set to the length of the given text. - :param kwargs: - :return: - """ - pass - - def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse: - return cast(_from_dict(obj=ParaphraseResponse, obj_dict=json), ParaphraseResponse) - - def _create_body( - self, - text: str, - style: Optional[str], - start_index: Optional[int], - end_index: Optional[int], - **kwargs, - ) -> Dict[str, Any]: - return { - "text": text, - "style": style, - "startIndex": start_index, - "endIndex": end_index, - **kwargs, - } diff --git a/ai21/clients/common/segmentation_base.py b/ai21/clients/common/segmentation_base.py deleted file mode 100644 index 074ba8c8..00000000 --- a/ai21/clients/common/segmentation_base.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict - -from ai21.models import DocumentType, SegmentationResponse - - -class Segmentation(ABC): - _module_name = "segmentation" - - @abstractmethod - def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: - """ - - :param source: Raw input text, or URL of a web page. - :param source_type: The type of the source - either TEXT or URL. - :param kwargs: - :return: - """ - pass - - def _create_body(self, source: str, source_type: str, **kwargs) -> Dict[str, Any]: - return {"source": source, "sourceType": source_type, **kwargs} diff --git a/ai21/clients/common/summarize_base.py b/ai21/clients/common/summarize_base.py deleted file mode 100644 index 540c1300..00000000 --- a/ai21/clients/common/summarize_base.py +++ /dev/null @@ -1,48 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Any, Dict, cast - -from ai21.models import SummarizeResponse, SummaryMethod -from ai21.models._pydantic_compatibility import _from_dict - - -class Summarize(ABC): - _module_name = "summarize" - - @abstractmethod - def create( - self, - source: str, - source_type: str, - *, - focus: Optional[str] = None, - summary_method: Optional[SummaryMethod] = None, - **kwargs, - ) -> SummarizeResponse: - """ - :param source: The input text, or URL of a web page to be summarized. - :param source_type: Either TEXT or URL - :param focus: Summaries focused on a topic of your choice. - :param summary_method: - :param kwargs: - :return: - """ - pass - - def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse: - return cast(_from_dict(obj=SummarizeResponse, obj_dict=json), SummarizeResponse) - - def _create_body( - self, - source: str, - source_type: str, - focus: Optional[str], - summary_method: Optional[str], - **kwargs, - ) -> Dict[str, Any]: - return { - "source": source, - "sourceType": source_type, - "focus": focus, - "summaryMethod": summary_method, - **kwargs, - } diff --git a/ai21/clients/common/summarize_by_segment_base.py b/ai21/clients/common/summarize_by_segment_base.py deleted file mode 100644 index 516d0ebe..00000000 --- a/ai21/clients/common/summarize_by_segment_base.py +++ /dev/null @@ -1,41 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Any, Dict - -from ai21.models import DocumentType, SummarizeBySegmentResponse - - -class SummarizeBySegment(ABC): - _module_name = "summarize-by-segment" - - @abstractmethod - def create( - self, - source: str, - source_type: DocumentType, - *, - focus: Optional[str] = None, - **kwargs, - ) -> SummarizeBySegmentResponse: - """ - - :param source: The input text, or URL of a web page to be summarized. - :param source_type: Either TEXT or URL - :param focus: Summaries focused on a topic of your choice. - :param kwargs: - :return: - """ - pass - - def _create_body( - self, - source: str, - source_type: str, - focus: Optional[str], - **kwargs, - ) -> Dict[str, Any]: - return { - "source": source, - "sourceType": source_type, - "focus": focus, - **kwargs, - } diff --git a/ai21/clients/sagemaker/__init__.py b/ai21/clients/sagemaker/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ai21/clients/sagemaker/ai21_sagemaker_client.py b/ai21/clients/sagemaker/ai21_sagemaker_client.py deleted file mode 100644 index 933f6fcc..00000000 --- a/ai21/clients/sagemaker/ai21_sagemaker_client.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Optional, Dict, Any - -import boto3 - -from ai21.ai21_env_config import AI21EnvConfig -from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer, AsyncSageMakerAnswer -from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion, AsyncSageMakerCompletion -from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC, AsyncSageMakerGEC -from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase, AsyncSageMakerParaphrase -from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize, AsyncSageMakerSummarize -from ai21.http_client.async_http_client import AsyncAI21HTTPClient -from ai21.http_client.http_client import AI21HTTPClient - - -class AI21SageMakerClient: - """ - :param endpoint_name: The name of the endpoint to use for the client. - :param region: The AWS region of the endpoint. - :param session: An optional boto3 session to use for the client. - """ - - def __init__( - self, - endpoint_name: str, - region: Optional[str] = None, - session: Optional["boto3.Session"] = None, - headers: Optional[Dict[str, Any]] = None, - timeout_sec: Optional[float] = None, - num_retries: Optional[int] = None, - http_client: Optional[AI21HTTPClient] = None, - **kwargs, - ): - region = region or AI21EnvConfig.aws_region - base_url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations" - self._http_client = http_client or AI21HTTPClient( - base_url=base_url, - headers=headers, - timeout_sec=timeout_sec, - num_retries=num_retries, - requires_api_key=False, - ) - - self.completion = SageMakerCompletion( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - self.paraphrase = SageMakerParaphrase( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - self.answer = SageMakerAnswer(base_url=base_url, region=region, client=self._http_client, aws_session=session) - self.gec = SageMakerGEC(base_url=base_url, region=region, client=self._http_client, aws_session=session) - self.summarize = SageMakerSummarize( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - - -class AsyncAI21SageMakerClient: - """ - :param endpoint_name: The name of the endpoint to use for the client. - :param region: The AWS region of the endpoint. - :param session: An optional boto3 session to use for the client. - """ - - def __init__( - self, - endpoint_name: str, - region: Optional[str] = None, - session: Optional["boto3.Session"] = None, - headers: Optional[Dict[str, Any]] = None, - timeout_sec: Optional[float] = None, - num_retries: Optional[int] = None, - http_client: Optional[AsyncAI21HTTPClient] = None, - **kwargs, - ): - region = region or AI21EnvConfig.aws_region - base_url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations" - self._http_client = http_client or AsyncAI21HTTPClient( - base_url=base_url, - headers=headers, - timeout_sec=timeout_sec, - num_retries=num_retries, - requires_api_key=False, - ) - - self.completion = AsyncSageMakerCompletion( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - self.paraphrase = AsyncSageMakerParaphrase( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - self.answer = AsyncSageMakerAnswer( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) - self.gec = AsyncSageMakerGEC(base_url=base_url, region=region, client=self._http_client, aws_session=session) - self.summarize = AsyncSageMakerSummarize( - base_url=base_url, region=region, client=self._http_client, aws_session=session - ) diff --git a/ai21/clients/sagemaker/constants.py b/ai21/clients/sagemaker/constants.py deleted file mode 100644 index 121e47b4..00000000 --- a/ai21/clients/sagemaker/constants.py +++ /dev/null @@ -1,9 +0,0 @@ -SAGEMAKER_MODEL_PACKAGE_NAMES = [ - "j2-light", - "j2-mid", - "j2-ultra", - "gec", - "contextual-answers", - "paraphrase", - "summarize", -] diff --git a/ai21/clients/sagemaker/resources/__init__.py b/ai21/clients/sagemaker/resources/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py deleted file mode 100644 index 4352a787..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ /dev/null @@ -1,32 +0,0 @@ -from ai21.clients.common.answer_base import Answer -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource -from ai21.models import AnswerResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class SageMakerAnswer(SageMakerResource, Answer): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - context: str, - question: str, - **kwargs, - ) -> AnswerResponse: - body = self._create_body(context=context, question=question) - response = self._post(body) - - return self._json_to_response(response.json()) - - -class AsyncSageMakerAnswer(AsyncSageMakerResource, Answer): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - context: str, - question: str, - **kwargs, - ) -> AnswerResponse: - body = self._create_body(context=context, question=question) - response = await self._post(body) - - return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py deleted file mode 100644 index 7e7ac651..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from typing import List, Dict, cast - -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource -from ai21.models import Penalty, CompletionsResponse -from ai21.models._pydantic_compatibility import _from_dict -from ai21.types import NotGiven, NOT_GIVEN -from ai21.utils.typing import remove_not_given - - -class SageMakerCompletion(SageMakerResource): - def create( - self, - prompt: str, - *, - max_tokens: int | NotGiven = NOT_GIVEN, - num_results: int | NotGiven = NOT_GIVEN, - min_tokens: int | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - top_k_return: int | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - frequency_penalty: Penalty | NotGiven = NOT_GIVEN, - presence_penalty: Penalty | NotGiven = NOT_GIVEN, - count_penalty: Penalty | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> CompletionsResponse: - """ - :param prompt: Text for model to complete - :param max_tokens: The maximum number of tokens to generate per result - :param num_results: Number of completions to sample and return. - :param min_tokens: The minimum number of tokens to generate per result. - :param temperature: A value controlling the "creativity" of the model's responses. - :param top_p: A value controlling the diversity of the model's responses. - :param top_k_return: The number of top-scoring tokens to consider for each generation step. - :param stop_sequences: Stops decoding if any of the strings is generated - :param frequency_penalty: A penalty applied to tokens that are frequently generated. - :param presence_penalty: A penalty applied to tokens that are already present in the prompt. - :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses - :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text - representations of the tokens and the floats are the biases themselves. A positive bias increases generation - probability for a given token and a negative bias decreases it. - :param kwargs: - :return: - """ - body = remove_not_given( - { - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, - "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, - "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, - "logitBias": logit_bias, - } - ) - - raw_response = self._post(body=body) - - return cast(_from_dict(obj=CompletionsResponse, obj_dict=raw_response.json()), CompletionsResponse) - - -class AsyncSageMakerCompletion(AsyncSageMakerResource): - async def create( - self, - prompt: str, - *, - max_tokens: int | NotGiven = NOT_GIVEN, - num_results: int | NotGiven = NOT_GIVEN, - min_tokens: int | NotGiven = NOT_GIVEN, - temperature: float | NotGiven = NOT_GIVEN, - top_p: float | NotGiven = NOT_GIVEN, - top_k_return: int | NotGiven = NOT_GIVEN, - stop_sequences: List[str] | NotGiven = NOT_GIVEN, - frequency_penalty: Penalty | NotGiven = NOT_GIVEN, - presence_penalty: Penalty | NotGiven = NOT_GIVEN, - count_penalty: Penalty | NotGiven = NOT_GIVEN, - logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> CompletionsResponse: - """ - :param prompt: Text for model to complete - :param max_tokens: The maximum number of tokens to generate per result - :param num_results: Number of completions to sample and return. - :param min_tokens: The minimum number of tokens to generate per result. - :param temperature: A value controlling the "creativity" of the model's responses. - :param top_p: A value controlling the diversity of the model's responses. - :param top_k_return: The number of top-scoring tokens to consider for each generation step. - :param stop_sequences: Stops decoding if any of the strings is generated - :param frequency_penalty: A penalty applied to tokens that are frequently generated. - :param presence_penalty: A penalty applied to tokens that are already present in the prompt. - :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses - :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text - representations of the tokens and the floats are the biases themselves. A positive bias increases generation - probability for a given token and a negative bias decreases it. - :param kwargs: - :return: - """ - body = remove_not_given( - { - "prompt": prompt, - "maxTokens": max_tokens, - "numResults": num_results, - "minTokens": min_tokens, - "temperature": temperature, - "topP": top_p, - "topKReturn": top_k_return, - "stopSequences": stop_sequences or [], - "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, - "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, - "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, - "logitBias": logit_bias, - } - ) - - raw_response = await self._post(body=body) - - return cast(_from_dict(obj=CompletionsResponse, obj_dict=raw_response.json()), CompletionsResponse) diff --git a/ai21/clients/sagemaker/resources/sagemaker_gec.py b/ai21/clients/sagemaker/resources/sagemaker_gec.py deleted file mode 100644 index 80a1f1f0..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_gec.py +++ /dev/null @@ -1,24 +0,0 @@ -from ai21.clients.common.gec_base import GEC -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource -from ai21.models import GECResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class SageMakerGEC(SageMakerResource, GEC): - @deprecated(V3_DEPRECATION_MESSAGE) - def create(self, text: str, **kwargs) -> GECResponse: - body = self._create_body(text=text) - - response = self._post(body) - - return self._json_to_response(response.json()) - - -class AsyncSageMakerGEC(AsyncSageMakerResource, GEC): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create(self, text: str, **kwargs) -> GECResponse: - body = self._create_body(text=text) - - response = await self._post(body) - - return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py deleted file mode 100644 index 0deb584e..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Optional - -from ai21.clients.common.paraphrase_base import Paraphrase -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource -from ai21.models import ParaphraseStyleType, ParaphraseResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class SageMakerParaphrase(SageMakerResource, Paraphrase): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - text: str, - *, - style: Optional[ParaphraseStyleType] = None, - start_index: Optional[int] = 0, - end_index: Optional[int] = None, - **kwargs, - ) -> ParaphraseResponse: - body = self._create_body( - text=text, - style=style, - start_index=start_index, - end_index=end_index, - ) - response = self._post(body=body) - - return self._json_to_response(response.json()) - - -class AsyncSageMakerParaphrase(AsyncSageMakerResource, Paraphrase): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - text: str, - *, - style: Optional[ParaphraseStyleType] = None, - start_index: Optional[int] = 0, - end_index: Optional[int] = None, - **kwargs, - ) -> ParaphraseResponse: - body = self._create_body( - text=text, - style=style, - start_index=start_index, - end_index=end_index, - ) - response = await self._post(body=body) - - return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_resource.py b/ai21/clients/sagemaker/resources/sagemaker_resource.py deleted file mode 100644 index 98982475..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_resource.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from typing import Any, Dict, Optional - -import boto3 -import httpx - -from ai21 import AI21APIError -from ai21.clients.aws.aws_authorization import AWSAuthorization -from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException, InternalDependencyException -from ai21.http_client.async_http_client import AsyncAI21HTTPClient -from ai21.http_client.http_client import AI21HTTPClient - - -def _handle_sagemaker_error(aws_error: AI21APIError) -> None: - status_code = aws_error.status_code - if status_code == 403: - raise AccessDenied(details=aws_error.details) - - if status_code == 404: - raise NotFound(details=aws_error.details) - - if status_code == 408: - raise APITimeoutError(details=aws_error.details) - - if status_code == 424: - raise ModelErrorException(details=aws_error.details) - - if status_code == 530: - raise InternalDependencyException(details=aws_error.details) - - raise aws_error - - -class SageMakerResource(ABC): - def __init__( - self, - region: str, - base_url: str, - client: AI21HTTPClient, - aws_session: Optional[boto3.Session] = None, - ): - self._client = client - self._aws_session = aws_session or boto3.Session(region_name=region) - self._aws_auth = AWSAuthorization(aws_session=self._aws_session) - self._base_url = base_url - - def _post( - self, - body: Dict[str, Any], - ) -> httpx.Response: - auth_headers = self._aws_auth.get_auth_headers( - service_name="sagemaker", url=self._base_url, method="POST", data=json.dumps(body) - ) - - try: - return self._client.execute_http_request(body=body, method="POST", extra_headers=auth_headers) - except AI21APIError as aws_error: - _handle_sagemaker_error(aws_error) - - -class AsyncSageMakerResource(ABC): - def __init__( - self, - base_url: str, - region: str, - client: AsyncAI21HTTPClient, - aws_session: Optional[boto3.Session] = None, - ): - self._client = client - self._aws_session = aws_session or boto3.Session(region_name=region) - self._base_url = base_url - self._aws_auth = AWSAuthorization(aws_session=self._aws_session) - - async def _post( - self, - body: Dict[str, Any], - ) -> httpx.Response: - auth_headers = self._aws_auth.get_auth_headers( - service_name="sagemaker", url=self._base_url, method="POST", data=json.dumps(body) - ) - - try: - return await self._client.execute_http_request(body=body, method="POST", extra_headers=auth_headers) - except AI21APIError as aws_error: - _handle_sagemaker_error(aws_error) diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py deleted file mode 100644 index 27afcdc9..00000000 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ /dev/null @@ -1,54 +0,0 @@ -from __future__ import annotations - -from typing import Optional - -from ai21.clients.common.summarize_base import Summarize -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource -from ai21.models import SummarizeResponse, SummaryMethod -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class SageMakerSummarize(SageMakerResource, Summarize): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - source: str, - source_type: str, - *, - focus: Optional[str] = None, - summary_method: Optional[SummaryMethod] = None, - **kwargs, - ) -> SummarizeResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - summary_method=summary_method, - ) - - response = self._post(body) - - return self._json_to_response(response.json()) - - -class AsyncSageMakerSummarize(AsyncSageMakerResource, Summarize): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - source: str, - source_type: str, - *, - focus: Optional[str] = None, - summary_method: Optional[SummaryMethod] = None, - **kwargs, - ) -> SummarizeResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - summary_method=summary_method, - ) - - response = await self._post(body) - - return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/sagemaker_session.py b/ai21/clients/sagemaker/sagemaker_session.py deleted file mode 100644 index 5473e870..00000000 --- a/ai21/clients/sagemaker/sagemaker_session.py +++ /dev/null @@ -1,59 +0,0 @@ -import json -import re - -import boto3 -from botocore.exceptions import ClientError - -from ai21.errors import BadRequest, ServiceUnavailable, AI21ServerError, AI21APIError -from ai21.http_client.base_http_client import handle_non_success_response -from ai21.logger import logger - -_ERROR_MSG_TEMPLATE = ( - r"Received client error \((.*?)\) from primary with message \"(.*?)\". " - r"See .* in account .* for more information." -) -_SAGEMAKER_RUNTIME_NAME = "sagemaker-runtime" - - -class SageMakerSession: - def __init__(self, session: boto3.Session, region: str, endpoint_name: str): - self._session = session if session else boto3.client(_SAGEMAKER_RUNTIME_NAME, region_name=region) - self._region = region - self._endpoint_name = endpoint_name - - def invoke_endpoint( - self, - input_json: str, - ): - try: - response = self._session.invoke_endpoint( - EndpointName=self._endpoint_name, - ContentType="application/json", - Accept="application/json", - Body=input_json, - ) - - return json.load(response["Body"]) - except ClientError as sm_client_error: - self._handle_client_error(sm_client_error) - except Exception as exception: - logger.error(f"Calling {self._endpoint_name} failed with Exception: {exception}") - raise exception - - def _handle_client_error(self, client_exception: "ClientError"): - error_response = client_exception.response - error_message = error_response.get("Error", {}).get("Message", "") - status_code = error_response.get("ResponseMetadata", {}).get("HTTPStatusCode", None) - # According to https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html#API_runtime_InvokeEndpoint_Errors - if status_code == 400: - raise BadRequest(details=error_message) - if status_code == 424: - error_message_template = re.compile(_ERROR_MSG_TEMPLATE) - model_status_code = int(error_message_template.search(error_message).group(1)) - model_error_message = error_message_template.search(error_message).group(2) - handle_non_success_response(model_status_code, model_error_message) - if status_code == 429 or status_code == 503: - raise ServiceUnavailable(details=error_message) - if status_code == 500: - raise AI21ServerError(details=error_message) - raise AI21APIError(status_code, details=error_message) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 0893cae7..34cdb491 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -1,28 +1,14 @@ -import warnings from typing import Optional, Any, Dict import httpx -from ai21_tokenizer import PreTrainedTokenizers from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.clients.studio.client_url_parser import create_client_url from ai21.clients.studio.resources.beta.beta import Beta -from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat -from ai21.clients.studio.resources.studio_completion import StudioCompletion -from ai21.clients.studio.resources.studio_custom_model import StudioCustomModel -from ai21.clients.studio.resources.studio_dataset import StudioDataset -from ai21.clients.studio.resources.studio_embed import StudioEmbed -from ai21.clients.studio.resources.studio_gec import StudioGEC -from ai21.clients.studio.resources.studio_improvements import StudioImprovements from ai21.clients.studio.resources.studio_library import StudioLibrary -from ai21.clients.studio.resources.studio_paraphrase import StudioParaphrase -from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation -from ai21.clients.studio.resources.studio_summarize import StudioSummarize -from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment from ai21.http_client.http_client import AI21HTTPClient from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer -from ai21.tokenizers.factory import get_tokenizer class AI21Client(AI21HTTPClient): @@ -54,26 +40,6 @@ def __init__( via=via, client=http_client, ) - self.completion = StudioCompletion(self) self.chat: StudioChat = StudioChat(self) - self.summarize = StudioSummarize(self) - self.embed = StudioEmbed(self) - self.gec = StudioGEC(self) - self.improvements = StudioImprovements(self) - self.paraphrase = StudioParaphrase(self) - self.summarize_by_segment = StudioSummarizeBySegment(self) - self.custom_model = StudioCustomModel(self) - self.dataset = StudioDataset(self) - self.answer = StudioAnswer(self) self.library = StudioLibrary(self) - self.segmentation = StudioSegmentation(self) self.beta = Beta(self) - - def count_tokens(self, text: str, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> int: - warnings.warn( - "Please use the global get_tokenizer() method directly instead of the AI21Client().count_tokens() method.", - DeprecationWarning, - ) - - tokenizer = get_tokenizer(tokenizer_name) - return tokenizer.count_tokens(text) diff --git a/ai21/clients/studio/async_ai21_client.py b/ai21/clients/studio/async_ai21_client.py index ecadceb6..4f7ed322 100644 --- a/ai21/clients/studio/async_ai21_client.py +++ b/ai21/clients/studio/async_ai21_client.py @@ -5,19 +5,8 @@ from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig from ai21.clients.studio.client_url_parser import create_client_url from ai21.clients.studio.resources.beta.async_beta import AsyncBeta -from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer from ai21.clients.studio.resources.studio_chat import AsyncStudioChat -from ai21.clients.studio.resources.studio_completion import AsyncStudioCompletion -from ai21.clients.studio.resources.studio_custom_model import AsyncStudioCustomModel -from ai21.clients.studio.resources.studio_dataset import AsyncStudioDataset -from ai21.clients.studio.resources.studio_embed import AsyncStudioEmbed -from ai21.clients.studio.resources.studio_gec import AsyncStudioGEC -from ai21.clients.studio.resources.studio_improvements import AsyncStudioImprovements from ai21.clients.studio.resources.studio_library import AsyncStudioLibrary -from ai21.clients.studio.resources.studio_paraphrase import AsyncStudioParaphrase -from ai21.clients.studio.resources.studio_segmentation import AsyncStudioSegmentation -from ai21.clients.studio.resources.studio_summarize import AsyncStudioSummarize -from ai21.clients.studio.resources.studio_summarize_by_segment import AsyncStudioSummarizeBySegment from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -50,17 +39,6 @@ def __init__( client=http_client, ) - self.completion = AsyncStudioCompletion(self) self.chat: AsyncStudioChat = AsyncStudioChat(self) - self.summarize = AsyncStudioSummarize(self) - self.embed = AsyncStudioEmbed(self) - self.gec = AsyncStudioGEC(self) - self.improvements = AsyncStudioImprovements(self) - self.paraphrase = AsyncStudioParaphrase(self) - self.summarize_by_segment = AsyncStudioSummarizeBySegment(self) - self.custom_model = AsyncStudioCustomModel(self) - self.dataset = AsyncStudioDataset(self) - self.answer = AsyncStudioAnswer(self) self.library = AsyncStudioLibrary(self) - self.segmentation = AsyncStudioSegmentation(self) self.beta = AsyncBeta(self) diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py deleted file mode 100644 index 948775b0..00000000 --- a/ai21/clients/studio/resources/studio_answer.py +++ /dev/null @@ -1,30 +0,0 @@ -from ai21.clients.common.answer_base import Answer -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import AnswerResponse -from ai21.version_utils import deprecated, V3_DEPRECATION_MESSAGE - - -class StudioAnswer(StudioResource, Answer): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - context: str, - question: str, - **kwargs, - ) -> AnswerResponse: - body = self._create_body(context=context, question=question, **kwargs) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=AnswerResponse) - - -class AsyncStudioAnswer(AsyncStudioResource, Answer): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - context: str, - question: str, - **kwargs, - ) -> AnswerResponse: - body = self._create_body(context=context, question=question, **kwargs) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=AnswerResponse) diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index acaf116d..db1930b5 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -20,7 +20,6 @@ def create( temperature: float | NotGiven = NOT_GIVEN, top_p: float | NotGiven = NOT_GIVEN, top_k_return: int | NotGiven = NOT_GIVEN, - custom_model: str | NotGiven = NOT_GIVEN, stop_sequences: List[str] | NotGiven = NOT_GIVEN, frequency_penalty: Penalty | NotGiven = NOT_GIVEN, presence_penalty: Penalty | NotGiven = NOT_GIVEN, @@ -30,7 +29,7 @@ def create( **kwargs, ) -> CompletionsResponse: model = self._get_model(model=model, model_id=kwargs.pop("model_id", None)) - path = self._get_completion_path(model=model, custom_model=custom_model) + path = self._get_completion_path(model=model) body = self._create_body( model=model, prompt=prompt, @@ -40,7 +39,6 @@ def create( temperature=temperature, top_p=top_p, top_k_return=top_k_return, - custom_model=custom_model, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, @@ -64,7 +62,6 @@ async def create( temperature: float | NotGiven = NOT_GIVEN, top_p: float | NotGiven = NOT_GIVEN, top_k_return: int | NotGiven = NOT_GIVEN, - custom_model: str | NotGiven = NOT_GIVEN, stop_sequences: List[str] | NotGiven = NOT_GIVEN, frequency_penalty: Penalty | NotGiven = NOT_GIVEN, presence_penalty: Penalty | NotGiven = NOT_GIVEN, @@ -74,7 +71,7 @@ async def create( **kwargs, ) -> CompletionsResponse: model = self._get_model(model=model, model_id=kwargs.pop("model_id", None)) - path = self._get_completion_path(model=model, custom_model=custom_model) + path = self._get_completion_path(model=model) body = self._create_body( model=model, prompt=prompt, @@ -84,7 +81,6 @@ async def create( temperature=temperature, top_p=top_p, top_k_return=top_k_return, - custom_model=custom_model, stop_sequences=stop_sequences, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py deleted file mode 100644 index 76ce63b7..00000000 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import List, Optional - -from ai21.clients.common.custom_model_base import CustomModel -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import CustomBaseModelResponse -from ai21.version_utils import deprecated, V3_DEPRECATION_MESSAGE - - -class StudioCustomModel(StudioResource, CustomModel): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - dataset_id: str, - model_name: str, - model_type: str, - *, - learning_rate: Optional[float] = None, - num_epochs: Optional[int] = None, - **kwargs, - ) -> None: - body = self._create_body( - dataset_id=dataset_id, - model_name=model_name, - model_type=model_type, - learning_rate=learning_rate, - num_epochs=num_epochs, - **kwargs, - ) - self._post(path=f"/{self._module_name}", body=body, response_cls=None) - - @deprecated(V3_DEPRECATION_MESSAGE) - def list(self) -> List[CustomBaseModelResponse]: - return self._get(path=f"/{self._module_name}", response_cls=List[CustomBaseModelResponse]) - - @deprecated(V3_DEPRECATION_MESSAGE) - def get(self, resource_id: str) -> CustomBaseModelResponse: - return self._get(path=f"/{self._module_name}/{resource_id}", response_cls=CustomBaseModelResponse) - - -class AsyncStudioCustomModel(AsyncStudioResource, CustomModel): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - dataset_id: str, - model_name: str, - model_type: str, - *, - learning_rate: Optional[float] = None, - num_epochs: Optional[int] = None, - **kwargs, - ) -> None: - body = self._create_body( - dataset_id=dataset_id, - model_name=model_name, - model_type=model_type, - learning_rate=learning_rate, - num_epochs=num_epochs, - **kwargs, - ) - await self._post(path=f"/{self._module_name}", body=body, response_cls=None) - - @deprecated(V3_DEPRECATION_MESSAGE) - async def list(self) -> List[CustomBaseModelResponse]: - return await self._get(path=f"/{self._module_name}", response_cls=List[CustomBaseModelResponse]) - - @deprecated(V3_DEPRECATION_MESSAGE) - async def get(self, resource_id: str) -> CustomBaseModelResponse: - return await self._get(path=f"/{self._module_name}/{resource_id}", response_cls=CustomBaseModelResponse) diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py deleted file mode 100644 index 0a9df9ca..00000000 --- a/ai21/clients/studio/resources/studio_dataset.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Optional, List - -from ai21.clients.common.dataset_base import Dataset -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import DatasetResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioDataset(StudioResource, Dataset): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - file_path: str, - dataset_name: str, - *, - selected_columns: Optional[str] = None, - approve_whitespace_correction: Optional[bool] = None, - delete_long_rows: Optional[bool] = None, - split_ratio: Optional[float] = None, - **kwargs, - ): - files = {"dataset_file": open(file_path, "rb")} - body = self._create_body( - dataset_name=dataset_name, - selected_columns=selected_columns, - approve_whitespace_correction=approve_whitespace_correction, - delete_long_rows=delete_long_rows, - split_ratio=split_ratio, - **kwargs, - ) - return self._post( - path=f"/{self._module_name}", - body=body, - files=files, - ) - - @deprecated(V3_DEPRECATION_MESSAGE) - def list(self) -> List[DatasetResponse]: - return self._get(path=f"/{self._module_name}", response_cls=List[DatasetResponse]) - - @deprecated(V3_DEPRECATION_MESSAGE) - def get(self, dataset_pid: str) -> DatasetResponse: - return self._get(path=f"/{self._module_name}/{dataset_pid}", response_cls=DatasetResponse) - - -class AsyncStudioDataset(AsyncStudioResource, Dataset): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - file_path: str, - dataset_name: str, - *, - selected_columns: Optional[str] = None, - approve_whitespace_correction: Optional[bool] = None, - delete_long_rows: Optional[bool] = None, - split_ratio: Optional[float] = None, - **kwargs, - ): - files = {"dataset_file": open(file_path, "rb")} - body = self._create_body( - dataset_name=dataset_name, - selected_columns=selected_columns, - approve_whitespace_correction=approve_whitespace_correction, - delete_long_rows=delete_long_rows, - split_ratio=split_ratio, - **kwargs, - ) - - return await self._post( - path=f"/{self._module_name}", - body=body, - files=files, - ) - - @deprecated(V3_DEPRECATION_MESSAGE) - async def list(self) -> List[DatasetResponse]: - return await self._get(path=f"/{self._module_name}", response_cls=List[DatasetResponse]) - - @deprecated(V3_DEPRECATION_MESSAGE) - async def get(self, dataset_pid: str) -> DatasetResponse: - return await self._get(path=f"/{self._module_name}/{dataset_pid}", response_cls=DatasetResponse) diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py deleted file mode 100644 index b6ffa100..00000000 --- a/ai21/clients/studio/resources/studio_embed.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import List, Optional - -from ai21.clients.common.embed_base import Embed -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import EmbedType, EmbedResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioEmbed(StudioResource, Embed): - @deprecated(V3_DEPRECATION_MESSAGE) - def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: - body = self._create_body(texts=texts, type=type, **kwargs) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=EmbedResponse) - - -class AsyncStudioEmbed(AsyncStudioResource, Embed): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: - body = self._create_body(texts=texts, type=type, **kwargs) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=EmbedResponse) diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py deleted file mode 100644 index f2486d58..00000000 --- a/ai21/clients/studio/resources/studio_gec.py +++ /dev/null @@ -1,20 +0,0 @@ -from ai21.clients.common.gec_base import GEC -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import GECResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioGEC(StudioResource, GEC): - @deprecated(V3_DEPRECATION_MESSAGE) - def create(self, text: str, **kwargs) -> GECResponse: - body = self._create_body(text=text, **kwargs) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=GECResponse) - - -class AsyncStudioGEC(AsyncStudioResource, GEC): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create(self, text: str, **kwargs) -> GECResponse: - body = self._create_body(text=text, **kwargs) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=GECResponse) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py deleted file mode 100644 index f5b8adb4..00000000 --- a/ai21/clients/studio/resources/studio_improvements.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import List - -from ai21.clients.common.improvements_base import Improvements -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import ImprovementType, ImprovementsResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioImprovements(StudioResource, Improvements): - @deprecated(V3_DEPRECATION_MESSAGE) - def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: - self._validate_types(types) - - body = self._create_body(text=text, types=types, **kwargs) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=ImprovementsResponse) - - -class AsyncStudioImprovements(AsyncStudioResource, Improvements): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse: - self._validate_types(types) - - body = self._create_body(text=text, types=types, **kwargs) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=ImprovementsResponse) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index e29311ea..26f3bf9f 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -5,7 +5,7 @@ from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient -from ai21.models import FileResponse, LibraryAnswerResponse, LibrarySearchResponse +from ai21.models import FileResponse from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -16,8 +16,6 @@ class StudioLibrary(StudioResource): def __init__(self, client: AI21HTTPClient): super().__init__(client) self.files = LibraryFiles(client) - self.search = LibrarySearch(client) - self.answer = LibraryAnswer(client) class LibraryFiles(StudioResource): @@ -74,64 +72,12 @@ def delete(self, file_id: str) -> None: self._delete(path=f"/{self._module_name}/{file_id}") -class LibrarySearch(StudioResource): - _module_name = "library/search" - - def create( - self, - query: str, - *, - path: Optional[str] | NotGiven = NOT_GIVEN, - field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, - max_segments: Optional[int] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> LibrarySearchResponse: - body = remove_not_given( - { - "query": query, - "path": path, - "fieldIds": field_ids, - "maxSegments": max_segments, - **kwargs, - } - ) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=LibrarySearchResponse) - - -class LibraryAnswer(StudioResource): - _module_name = "library/answer" - - def create( - self, - question: str, - *, - path: Optional[str] | NotGiven = NOT_GIVEN, - field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, - labels: Optional[List[str]] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> LibraryAnswerResponse: - body = remove_not_given( - { - "question": question, - "path": path, - "fieldIds": field_ids, - "labels": labels, - **kwargs, - } - ) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=LibraryAnswerResponse) - - class AsyncStudioLibrary(AsyncStudioResource): _module_name = "library/files" def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) self.files = AsyncLibraryFiles(client) - self.search = AsyncLibrarySearch(client) - self.answer = AsyncLibraryAnswer(client) class AsyncLibraryFiles(AsyncStudioResource): @@ -186,53 +132,3 @@ async def update( async def delete(self, file_id: str) -> None: await self._delete(path=f"/{self._module_name}/{file_id}") - - -class AsyncLibrarySearch(AsyncStudioResource): - _module_name = "library/search" - - async def create( - self, - query: str, - *, - path: Optional[str] | NotGiven = NOT_GIVEN, - field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, - max_segments: Optional[int] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> LibrarySearchResponse: - body = remove_not_given( - { - "query": query, - "path": path, - "fieldIds": field_ids, - "maxSegments": max_segments, - **kwargs, - } - ) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=LibrarySearchResponse) - - -class AsyncLibraryAnswer(AsyncStudioResource): - _module_name = "library/answer" - - async def create( - self, - question: str, - *, - path: Optional[str] | NotGiven = NOT_GIVEN, - field_ids: Optional[List[str]] | NotGiven = NOT_GIVEN, - labels: Optional[List[str]] | NotGiven = NOT_GIVEN, - **kwargs, - ) -> LibraryAnswerResponse: - body = remove_not_given( - { - "question": question, - "path": path, - "fieldIds": field_ids, - "labels": labels, - **kwargs, - } - ) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=LibraryAnswerResponse) diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py deleted file mode 100644 index 194e378f..00000000 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Optional - -from ai21.clients.common.paraphrase_base import Paraphrase -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import ParaphraseStyleType, ParaphraseResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioParaphrase(StudioResource, Paraphrase): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - text: str, - *, - style: Optional[ParaphraseStyleType] = None, - start_index: Optional[int] = None, - end_index: Optional[int] = None, - **kwargs, - ) -> ParaphraseResponse: - body = self._create_body( - text=text, - style=style, - start_index=start_index, - end_index=end_index, - **kwargs, - ) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=ParaphraseResponse) - - -class AsyncStudioParaphrase(AsyncStudioResource, Paraphrase): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - text: str, - *, - style: Optional[ParaphraseStyleType] = None, - start_index: Optional[int] = None, - end_index: Optional[int] = None, - **kwargs, - ) -> ParaphraseResponse: - body = self._create_body( - text=text, - style=style, - start_index=start_index, - end_index=end_index, - **kwargs, - ) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=ParaphraseResponse) diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py deleted file mode 100644 index 7bcd0e6f..00000000 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ /dev/null @@ -1,20 +0,0 @@ -from ai21.clients.common.segmentation_base import Segmentation -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import DocumentType, SegmentationResponse -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioSegmentation(StudioResource, Segmentation): - @deprecated(V3_DEPRECATION_MESSAGE) - def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: - body = self._create_body(source=source, source_type=source_type.value, **kwargs) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=SegmentationResponse) - - -class AsyncStudioSegmentation(AsyncStudioResource, Segmentation): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: - body = self._create_body(source=source, source_type=source_type.value, **kwargs) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=SegmentationResponse) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py deleted file mode 100644 index b8a86cff..00000000 --- a/ai21/clients/studio/resources/studio_summarize.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import Optional - -from ai21.clients.common.summarize_base import Summarize -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import SummarizeResponse, SummaryMethod -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioSummarize(StudioResource, Summarize): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - source: str, - source_type: str, - *, - focus: Optional[str] = None, - summary_method: Optional[SummaryMethod] = None, - **kwargs, - ) -> SummarizeResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - summary_method=summary_method, - **kwargs, - ) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeResponse) - - -class AsyncStudioSummarize(AsyncStudioResource, Summarize): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - source: str, - source_type: str, - *, - focus: Optional[str] = None, - summary_method: Optional[SummaryMethod] = None, - **kwargs, - ) -> SummarizeResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - summary_method=summary_method, - **kwargs, - ) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeResponse) diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py deleted file mode 100644 index 0c47203a..00000000 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Optional - -from ai21.clients.common.summarize_by_segment_base import SummarizeBySegment -from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models import SummarizeBySegmentResponse, DocumentType -from ai21.version_utils import V3_DEPRECATION_MESSAGE, deprecated - - -class StudioSummarizeBySegment(StudioResource, SummarizeBySegment): - @deprecated(V3_DEPRECATION_MESSAGE) - def create( - self, - source: str, - source_type: DocumentType, - *, - focus: Optional[str] = None, - **kwargs, - ) -> SummarizeBySegmentResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - **kwargs, - ) - - return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeBySegmentResponse) - - -class AsyncStudioSummarizeBySegment(AsyncStudioResource, SummarizeBySegment): - @deprecated(V3_DEPRECATION_MESSAGE) - async def create( - self, - source: str, - source_type: DocumentType, - *, - focus: Optional[str] = None, - **kwargs, - ) -> SummarizeBySegmentResponse: - body = self._create_body( - source=source, - source_type=source_type, - focus=focus, - **kwargs, - ) - - return await self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeBySegmentResponse) diff --git a/ai21/models/__init__.py b/ai21/models/__init__.py index c10d965a..7d45c031 100644 --- a/ai21/models/__init__.py +++ b/ai21/models/__init__.py @@ -1,11 +1,7 @@ from ai21.models.chat.role_type import RoleType from ai21.models.chat_message import ChatMessage from ai21.models.document_type import DocumentType -from ai21.models.embed_type import EmbedType -from ai21.models.improvement_type import ImprovementType -from ai21.models.paraphrase_style_type import ParaphraseStyleType from ai21.models.penalty import Penalty -from ai21.models.responses.answer_response import AnswerResponse from ai21.models.responses.chat_response import ChatResponse, ChatOutput, FinishReason from ai21.models.responses.completion_response import ( CompletionsResponse, @@ -15,30 +11,13 @@ Prompt, ) from ai21.models.responses.conversational_rag_response import ConversationalRagResponse, ConversationalRagSource -from ai21.models.responses.custom_model_response import CustomBaseModelResponse, BaseModelMetadata -from ai21.models.responses.dataset_response import DatasetResponse -from ai21.models.responses.embed_response import EmbedResponse, EmbedResult from ai21.models.responses.file_response import FileResponse -from ai21.models.responses.gec_response import GECResponse, Correction, CorrectionType -from ai21.models.responses.improvement_response import ImprovementsResponse, Improvement -from ai21.models.responses.library_answer_response import LibraryAnswerResponse, SourceDocument -from ai21.models.responses.library_search_response import LibrarySearchResponse, LibrarySearchResult -from ai21.models.responses.paraphrase_response import ParaphraseResponse, Suggestion -from ai21.models.responses.segmentation_response import SegmentationResponse -from ai21.models.responses.summarize_by_segment_response import SummarizeBySegmentResponse, SegmentSummary, Highlight -from ai21.models.responses.summarize_response import SummarizeResponse -from ai21.models.summary_method import SummaryMethod __all__ = [ "ChatMessage", "RoleType", "Penalty", - "EmbedType", - "ImprovementType", - "ParaphraseStyleType", "DocumentType", - "SummaryMethod", - "AnswerResponse", "ChatResponse", "ChatOutput", "FinishReason", @@ -47,28 +26,7 @@ "CompletionFinishReason", "CompletionData", "Prompt", - "CustomBaseModelResponse", - "BaseModelMetadata", - "DatasetResponse", - "EmbedResponse", - "EmbedResult", "FileResponse", - "GECResponse", - "Correction", - "CorrectionType", - "ImprovementsResponse", - "Improvement", - "LibraryAnswerResponse", - "SourceDocument", - "LibrarySearchResponse", - "LibrarySearchResult", - "ParaphraseResponse", - "Suggestion", - "SegmentationResponse", - "SegmentSummary", - "Highlight", - "SummarizeBySegmentResponse", - "SummarizeResponse", "ConversationalRagResponse", "ConversationalRagSource", ] diff --git a/ai21/models/embed_type.py b/ai21/models/embed_type.py deleted file mode 100644 index d1268a86..00000000 --- a/ai21/models/embed_type.py +++ /dev/null @@ -1,6 +0,0 @@ -from enum import Enum - - -class EmbedType(str, Enum): - QUERY = "query" - SEGMENT = "segment" diff --git a/ai21/models/improvement_type.py b/ai21/models/improvement_type.py deleted file mode 100644 index 0774c4b0..00000000 --- a/ai21/models/improvement_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - - -class ImprovementType(str, Enum): - FLUENCY = "fluency" - VOCABULARY_SPECIFICITY = "vocabulary/specificity" - VOCABULARY_VARIETY = "vocabulary/variety" - CLARITY_SHORT_SENTENCES = "clarity/short-sentences" - CLARITY_CONCISENESS = "clarity/conciseness" diff --git a/ai21/models/paraphrase_style_type.py b/ai21/models/paraphrase_style_type.py deleted file mode 100644 index b7d7bd54..00000000 --- a/ai21/models/paraphrase_style_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - - -class ParaphraseStyleType(str, Enum): - LONG = "long" - SHORT = "short" - FORMAL = "formal" - CASUAL = "casual" - GENERAL = "general" diff --git a/ai21/models/responses/answer_response.py b/ai21/models/responses/answer_response.py deleted file mode 100644 index 9e10d7d0..00000000 --- a/ai21/models/responses/answer_response.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Optional - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class AnswerResponse(AI21BaseModel): - id: str - answer_in_context: Optional[bool] = Field(default=None, alias="answerInContext") - answer: Optional[str] = None diff --git a/ai21/models/responses/custom_model_response.py b/ai21/models/responses/custom_model_response.py deleted file mode 100644 index 041c4c40..00000000 --- a/ai21/models/responses/custom_model_response.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class BaseModelMetadata(AI21BaseModel): - learning_rate: float = Field(alias="learningRate") - num_epochs: int = Field(alias="numEpochs") - default_epoch: int = Field(alias="defaultEpoch") - cost: Optional[float] = None - validation_loss: Optional[float] = Field(None, alias="validationLoss") - eval_loss: Optional[float] = Field(None, alias="evalLoss") - - -class CustomBaseModelResponse(AI21BaseModel): - id: str - name: str - tier: str - model_type: str = Field(alias="modelType") - custom_model_type: Optional[str] = Field(alias="customModelType") - status: str - model_metadata: BaseModelMetadata = Field(alias="modelMetadata") - dataset_id: int = Field(alias="datasetId") - dataset_name: str = Field(alias="datasetName") - creation_date: str = Field(alias="creationDate") - current_epoch: Optional[int] = Field(alias="currentEpoch") - size: Optional[str] = None diff --git a/ai21/models/responses/dataset_response.py b/ai21/models/responses/dataset_response.py deleted file mode 100644 index d848de4e..00000000 --- a/ai21/models/responses/dataset_response.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class DatasetResponse(AI21BaseModel): - id: str - dataset_name: str = Field(alias="datasetName") - size_bytes: int = Field(alias="sizeBytes") - creation_date: datetime = Field(alias="creationDate") - num_examples: int = Field(alias="numExamples") - validation_num_examples: int = Field(alias="validationNumExamples") - train_num_examples: int = Field(alias="trainNumExamples") - num_models_used: int = Field(alias="numModelsUsed") diff --git a/ai21/models/responses/embed_response.py b/ai21/models/responses/embed_response.py deleted file mode 100644 index 6cb10921..00000000 --- a/ai21/models/responses/embed_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import List - -from ai21.models.ai21_base_model import AI21BaseModel - - -class EmbedResult(AI21BaseModel): - embedding: List[float] - - def __init__(self, embedding: List[float]): - super().__init__(embedding=embedding) - - -class EmbedResponse(AI21BaseModel): - id: str - results: List[EmbedResult] diff --git a/ai21/models/responses/gec_response.py b/ai21/models/responses/gec_response.py deleted file mode 100644 index 5cfb48c4..00000000 --- a/ai21/models/responses/gec_response.py +++ /dev/null @@ -1,28 +0,0 @@ -from enum import Enum -from typing import List - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class CorrectionType(str, Enum): - GRAMMAR = "Grammar" - MISSING_WORD = "Missing Word" - PUNCTUATION = "Punctuation" - SPELLING = "Spelling" - WORD_REPETITION = "Word Repetition" - WRONG_WORD = "Wrong Word" - - -class Correction(AI21BaseModel): - suggestion: str - start_index: int = Field(alias="startIndex") - end_index: int = Field(alias="endIndex") - original_text: str = Field(alias="originalText") - correction_type: CorrectionType = Field(alias="correctionType") - - -class GECResponse(AI21BaseModel): - id: str - corrections: List[Correction] diff --git a/ai21/models/responses/improvement_response.py b/ai21/models/responses/improvement_response.py deleted file mode 100644 index c9eea2f1..00000000 --- a/ai21/models/responses/improvement_response.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import List - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class Improvement(AI21BaseModel): - suggestions: List[str] - start_index: int = Field(alias="startIndex") - end_index: int = Field(alias="endIndex") - original_text: str = Field(alias="originalText") - improvement_type: str = Field(alias="improvementType") - - -class ImprovementsResponse(AI21BaseModel): - id: str - improvements: List[Improvement] diff --git a/ai21/models/responses/library_answer_response.py b/ai21/models/responses/library_answer_response.py deleted file mode 100644 index 11f0e683..00000000 --- a/ai21/models/responses/library_answer_response.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import List, Optional - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class SourceDocument(AI21BaseModel): - file_id: str = Field(alias="fileId") - name: str - highlights: List[str] - public_url: Optional[str] = Field(default=None, alias="publicUrl") - - -class LibraryAnswerResponse(AI21BaseModel): - id: str - answer_in_context: bool = Field(alias="answerInContext") - answer: Optional[str] = None - sources: Optional[List[SourceDocument]] = None diff --git a/ai21/models/responses/library_search_response.py b/ai21/models/responses/library_search_response.py deleted file mode 100644 index bdfafd6f..00000000 --- a/ai21/models/responses/library_search_response.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Optional, List - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class LibrarySearchResult(AI21BaseModel): - text: str - file_id: str = Field(alias="fileId") - file_name: str = Field(alias="fileName") - score: float - order: Optional[int] = None - public_url: Optional[str] = Field(default=None, alias="publicUrl") - labels: Optional[List[str]] = None - - -class LibrarySearchResponse(AI21BaseModel): - id: str - results: List[LibrarySearchResult] diff --git a/ai21/models/responses/paraphrase_response.py b/ai21/models/responses/paraphrase_response.py deleted file mode 100644 index 25e01da9..00000000 --- a/ai21/models/responses/paraphrase_response.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import List - -from ai21.models.ai21_base_model import AI21BaseModel - - -class Suggestion(AI21BaseModel): - text: str - - -class ParaphraseResponse(AI21BaseModel): - id: str - suggestions: List[Suggestion] diff --git a/ai21/models/responses/segmentation_response.py b/ai21/models/responses/segmentation_response.py deleted file mode 100644 index 8ea8d190..00000000 --- a/ai21/models/responses/segmentation_response.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import List - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class Segment(AI21BaseModel): - segment_text: str = Field(alias="segmentText") - segment_type: str = Field(alias="segmentType") - - -class SegmentationResponse(AI21BaseModel): - id: str - segments: List[Segment] diff --git a/ai21/models/responses/summarize_by_segment_response.py b/ai21/models/responses/summarize_by_segment_response.py deleted file mode 100644 index 2766e95f..00000000 --- a/ai21/models/responses/summarize_by_segment_response.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import List, Optional - -from pydantic import Field - -from ai21.models.ai21_base_model import AI21BaseModel - - -class Highlight(AI21BaseModel): - text: str - start_index: int = Field(alias="startIndex") - end_index: int = Field(alias="endIndex") - - -class SegmentSummary(AI21BaseModel): - summary: Optional[str] = None - segment_text: Optional[str] = Field(default=None, alias="segmentText") - segment_html: Optional[str] = Field(default=None, alias="segmentHtml") - segment_type: str = Field(alias="segmentType") - has_summary: bool = Field(alias="hasSummary") - highlights: List[Highlight] - - -class SummarizeBySegmentResponse(AI21BaseModel): - id: str - segments: List[SegmentSummary] diff --git a/ai21/models/responses/summarize_response.py b/ai21/models/responses/summarize_response.py deleted file mode 100644 index 4d7cc0d8..00000000 --- a/ai21/models/responses/summarize_response.py +++ /dev/null @@ -1,6 +0,0 @@ -from ai21.models.ai21_base_model import AI21BaseModel - - -class SummarizeResponse(AI21BaseModel): - id: str - summary: str diff --git a/ai21/models/summary_method.py b/ai21/models/summary_method.py deleted file mode 100644 index b4b05554..00000000 --- a/ai21/models/summary_method.py +++ /dev/null @@ -1,7 +0,0 @@ -from enum import Enum - - -class SummaryMethod(str, Enum): - SEGMENTS = "segments" - GUIDED = "guided" - FULL_DOCUMENT = "fullDocument" diff --git a/ai21/services/__init__.py b/ai21/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py deleted file mode 100644 index 5fdbc4a3..00000000 --- a/ai21/services/sagemaker.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import List - -from ai21 import AI21EnvConfig -from ai21.clients.sagemaker.constants import ( - SAGEMAKER_MODEL_PACKAGE_NAMES, -) -from ai21.errors import ModelPackageDoesntExistError -from ai21.http_client.http_client import AI21HTTPClient - -_JUMPSTART_ENDPOINT = "jumpstart" -_LIST_VERSIONS_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/list_versions" -_GET_ARN_ENDPOINT = f"{_JUMPSTART_ENDPOINT}/get_model_version_arn" - -LATEST_VERSION_STR = "latest" - - -class SageMaker: - @classmethod - def get_model_package_arn(cls, model_name: str, region: str, version: str = LATEST_VERSION_STR) -> str: - _assert_model_package_exists(model_name=model_name, region=region) - - client = cls._create_ai21_http_client(path=_GET_ARN_ENDPOINT) - - response = client.execute_http_request( - method="POST", - body={ - "modelName": model_name, - "region": region, - "version": version, - }, - ) - - arn = response.json()["arn"] - - if not arn: - raise ModelPackageDoesntExistError(model_name=model_name, region=region, version=version) - - return arn - - @classmethod - def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: - _assert_model_package_exists(model_name=model_name, region=region) - - client = cls._create_ai21_http_client(path=_LIST_VERSIONS_ENDPOINT) - - response = client.execute_http_request( - method="POST", - body={ - "modelName": model_name, - "region": region, - }, - ) - - return response.json()["versions"] - - @classmethod - def _create_ai21_http_client(cls, path: str) -> AI21HTTPClient: - return AI21HTTPClient( - api_key=AI21EnvConfig.api_key, - base_url=f"{AI21EnvConfig.api_host}/studio/v1/{path}", - requires_api_key=False, - timeout_sec=AI21EnvConfig.timeout_sec, - num_retries=AI21EnvConfig.num_retries, - ) - - -def _assert_model_package_exists(model_name, region): - if model_name not in SAGEMAKER_MODEL_PACKAGE_NAMES: - raise ModelPackageDoesntExistError(model_name=model_name, region=region) diff --git a/examples/sagemaker/answer.py b/examples/sagemaker/answer.py deleted file mode 100644 index fb177cc8..00000000 --- a/examples/sagemaker/answer.py +++ /dev/null @@ -1,14 +0,0 @@ -from ai21 import AI21SageMakerClient - -client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") - -response = client.answer.create( - context="Holland is a geographical region[2] and former province on the western coast" - " of the Netherlands.[2] From the " - "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " - "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " - "economic power, dominating the other provinces of the newly independent Dutch Republic.", - question="When did Holland become an economic power?", -) - -print(response.answer) diff --git a/examples/sagemaker/async_answer.py b/examples/sagemaker/async_answer.py deleted file mode 100644 index ceeee832..00000000 --- a/examples/sagemaker/async_answer.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21SageMakerClient - -client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -async def main(): - response = await client.answer.create( - context="Holland is a geographical region[2] and former province on the western coast" - " of the Netherlands.[2] From the 10th to the 16th century, Holland proper was a unified political region " - "within the Holy Roman Empire as a county ruled by the counts of Holland. By the 17th century, the province " - "of Holland had risen to become a maritime and economic power, dominating the other provinces of the newly " - "independent Dutch Republic.", - question="When did Holland become an economic power?", - ) - - print(response.answer) - - -asyncio.run(main()) diff --git a/examples/sagemaker/async_completion.py b/examples/sagemaker/async_completion.py deleted file mode 100644 index a9d445b5..00000000 --- a/examples/sagemaker/async_completion.py +++ /dev/null @@ -1,47 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21SageMakerClient - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -async def main(): - response = await client.completion.create(prompt=prompt, max_tokens=2) - - print(response.completions[0].data.text) - - print(response.prompt.tokens[0]["textRange"]["start"]) - - -asyncio.run(main()) diff --git a/examples/sagemaker/async_gec.py b/examples/sagemaker/async_gec.py deleted file mode 100644 index 526498f3..00000000 --- a/examples/sagemaker/async_gec.py +++ /dev/null @@ -1,14 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21SageMakerClient - -client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -async def main(): - response = client.gec.create(text="roc and rolle") - - print(response.corrections[0].suggestion) - - -asyncio.run(main()) diff --git a/examples/sagemaker/async_paraphrase.py b/examples/sagemaker/async_paraphrase.py deleted file mode 100644 index 4bff2c49..00000000 --- a/examples/sagemaker/async_paraphrase.py +++ /dev/null @@ -1,17 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21SageMakerClient - -client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -async def main(): - response = await client.paraphrase.create( - text="What's the difference between Scottish Fold and British?", - style="formal", - ) - - print(response.suggestions[0].text) - - -asyncio.run(main()) diff --git a/examples/sagemaker/async_summarization.py b/examples/sagemaker/async_summarization.py deleted file mode 100644 index 44cb4bff..00000000 --- a/examples/sagemaker/async_summarization.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21SageMakerClient - -client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -async def main(): - response = await client.summarize.create( - source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2]" - " From the 10th to the 16th century, " - "Holland proper was a unified political region within the Holy Roman Empire as a" - " county ruled by the counts of Holland. By the 17th century, " - "the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch " - "Republic.", - source_type="TEXT", - ) - - print(response.summary) - - -asyncio.run(main()) diff --git a/examples/sagemaker/completion.py b/examples/sagemaker/completion.py deleted file mode 100644 index 2baa8634..00000000 --- a/examples/sagemaker/completion.py +++ /dev/null @@ -1,39 +0,0 @@ -from ai21 import AI21SageMakerClient - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") -response = client.completion.create(prompt=prompt, max_tokens=2) - -print(response) - -print(response.prompt.tokens[0]["textRange"]["start"]) diff --git a/examples/sagemaker/gec.py b/examples/sagemaker/gec.py deleted file mode 100644 index 1b94f705..00000000 --- a/examples/sagemaker/gec.py +++ /dev/null @@ -1,6 +0,0 @@ -from ai21 import AI21SageMakerClient - -client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") -response = client.gec.create(text="roc and rolle") - -print(response.corrections[0].suggestion) diff --git a/examples/sagemaker/get_model_package_arn.py b/examples/sagemaker/get_model_package_arn.py deleted file mode 100644 index c7fdb7c4..00000000 --- a/examples/sagemaker/get_model_package_arn.py +++ /dev/null @@ -1,4 +0,0 @@ -from ai21 import SageMaker - -print(SageMaker.list_model_package_versions(model_name="j2-mid", region="us-east-1")) -print(SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1")) diff --git a/examples/sagemaker/paraphrase.py b/examples/sagemaker/paraphrase.py deleted file mode 100644 index 25d468a9..00000000 --- a/examples/sagemaker/paraphrase.py +++ /dev/null @@ -1,11 +0,0 @@ -from ai21 import AI21SageMakerClient - -client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") - - -response = client.paraphrase.create( - text="What's the difference between Scottish Fold and British?", - style="formal", -) - -print(response.suggestions[0].text) diff --git a/examples/sagemaker/summarization.py b/examples/sagemaker/summarization.py deleted file mode 100644 index cea5620a..00000000 --- a/examples/sagemaker/summarization.py +++ /dev/null @@ -1,15 +0,0 @@ -from ai21 import AI21SageMakerClient - -client = AI21SageMakerClient(endpoint_name="sm_endpoint_name") -response = client.summarize.create( - source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2]" - " From the 10th to the 16th century, " - "Holland proper was a unified political region within the Holy Roman Empire as a" - " county ruled by the counts of Holland. By the 17th century, " - "the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch " - "Republic.", - source_type="TEXT", -) - -print(response.summary) diff --git a/examples/studio/answer.py b/examples/studio/answer.py deleted file mode 100644 index 2d1a7c8a..00000000 --- a/examples/studio/answer.py +++ /dev/null @@ -1,13 +0,0 @@ -from ai21 import AI21Client - - -client = AI21Client() -response = client.answer.create( - context="Holland is a geographical region[2] and former province on the western coast of" - " the Netherlands.[2] From the " - "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " - "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " - "economic power, dominating the other provinces of the newly independent Dutch Republic.", - question="When did Holland become an economic power?", -) -print(response) diff --git a/examples/studio/async_answer.py b/examples/studio/async_answer.py deleted file mode 100644 index 295f9857..00000000 --- a/examples/studio/async_answer.py +++ /dev/null @@ -1,21 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - - -client = AsyncAI21Client() - - -async def main(): - response = await client.answer.create( - context="Holland is a geographical region[2] and former province on the western coast of" - " the Netherlands.[2] From the " - "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as " - "a county ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become " - "a maritime and economic power, dominating the other provinces of the newly independent Dutch Republic.", - question="When did Holland become an economic power?", - ) - print(response) - - -asyncio.run(main()) diff --git a/examples/studio/async_completion.py b/examples/studio/async_completion.py deleted file mode 100644 index 16eb85bb..00000000 --- a/examples/studio/async_completion.py +++ /dev/null @@ -1,84 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import Penalty - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AsyncAI21Client() - - -async def main(): - response = await client.completion.create( - prompt=prompt, - max_tokens=2, - model="j2-light", - temperature=0, - top_p=1, - top_k_return=0, - stop_sequences=["##"], - num_results=1, - custom_model=None, - epoch=1, - logit_bias={"▁I'm▁sorry": -100.0}, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - print(response) - print(response.completions[0].data.text) - print(response.prompt.tokens[0]["textRange"]["start"]) - - -asyncio.run(main()) diff --git a/examples/studio/async_custom_model.py b/examples/studio/async_custom_model.py deleted file mode 100644 index 428faefb..00000000 --- a/examples/studio/async_custom_model.py +++ /dev/null @@ -1,26 +0,0 @@ -import asyncio -import uuid - -from ai21 import AsyncAI21Client - - -client = AsyncAI21Client() - - -async def main(): - my_datasets = await client.dataset.list() - if len(my_datasets) > 0: - client.custom_model.create( - dataset_id=my_datasets[0].id, - model_name=f"test-{(str(uuid.uuid4()))[:20]}-asaf", - model_type="j2-mid", - ) - my_models = client.custom_model.list() - - print(my_models) - print(my_models[0]) - - print(client.custom_model.get(my_models[0].id)) - - -asyncio.run(main()) diff --git a/examples/studio/async_custom_model_completion.py b/examples/studio/async_custom_model_completion.py deleted file mode 100644 index 97346d29..00000000 --- a/examples/studio/async_custom_model_completion.py +++ /dev/null @@ -1,50 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AsyncAI21Client() - - -async def main(): - response = await client.completion.create( - model="j2-grande", - custom_model="test-6acbf857-d8fb-4bd9-a-asaf", - prompt=prompt, - max_tokens=2, - ) - print(response) - print(response.prompt.tokens[0]["textRange"]["start"]) - - -asyncio.run(main()) diff --git a/examples/studio/async_dataset.py b/examples/studio/async_dataset.py deleted file mode 100644 index 81c03838..00000000 --- a/examples/studio/async_dataset.py +++ /dev/null @@ -1,19 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - -file_path = "" - -client = AsyncAI21Client() - - -async def main(): - await client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") - result = await client.dataset.list() - print(result) - first_ds_id = result[0].id - result = await client.dataset.get(first_ds_id) - print(result) - - -asyncio.run(main()) diff --git a/examples/studio/async_embed.py b/examples/studio/async_embed.py deleted file mode 100644 index aa701cfa..00000000 --- a/examples/studio/async_embed.py +++ /dev/null @@ -1,17 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import EmbedType - -client = AsyncAI21Client() - - -async def main(): - response = await client.embed.create( - texts=["Holland is a geographical region[2] and former province on the western coast of the Netherlands."], - type=EmbedType.SEGMENT, - ) - print("embed: ", response.results[0].embedding) - - -asyncio.run(main()) diff --git a/examples/studio/async_gec.py b/examples/studio/async_gec.py deleted file mode 100644 index 75eb29ad..00000000 --- a/examples/studio/async_gec.py +++ /dev/null @@ -1,19 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - -client = AsyncAI21Client() - - -async def main(): - response = await client.gec.create(text="jazzz is a great stile of music") - - print("---------------------") - print(response.corrections[0].suggestion) - print(response.corrections[0].start_index) - print(response.corrections[0].end_index) - print(response.corrections[0].original_text) - print(response.corrections[0].correction_type) - - -asyncio.run(main()) diff --git a/examples/studio/async_improvements.py b/examples/studio/async_improvements.py deleted file mode 100644 index f3e5dc61..00000000 --- a/examples/studio/async_improvements.py +++ /dev/null @@ -1,24 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import ImprovementType - -client = AsyncAI21Client() - - -async def main(): - response = await client.improvements.create( - text="Affiliated with the profession of project management," - " I have ameliorated myself with a different set of hard skills as well as soft skills.", - types=[ImprovementType.FLUENCY], - ) - - print(response.improvements[0].original_text) - print(response.improvements[0].suggestions) - print(response.improvements[0].suggestions[0]) - print(response.improvements[0].improvement_type) - print(response.improvements[1].start_index) - print(response.improvements[1].end_index) - - -asyncio.run(main()) diff --git a/examples/studio/async_library_answer.py b/examples/studio/async_library_answer.py deleted file mode 100644 index c9f2f040..00000000 --- a/examples/studio/async_library_answer.py +++ /dev/null @@ -1,13 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - -client = AsyncAI21Client() - - -async def main(): - response = await client.library.answer.create(question="Can you tell me something about Holland?") - print(response) - - -asyncio.run(main()) diff --git a/examples/studio/async_library_search.py b/examples/studio/async_library_search.py deleted file mode 100644 index 8e59cd20..00000000 --- a/examples/studio/async_library_search.py +++ /dev/null @@ -1,13 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client - -client = AsyncAI21Client() - - -async def main(): - response = await client.library.search.create(query="cat colors", max_segments=2) - print(response) - - -asyncio.run(main()) diff --git a/examples/studio/async_paraphrase.py b/examples/studio/async_paraphrase.py deleted file mode 100644 index 28def7d4..00000000 --- a/examples/studio/async_paraphrase.py +++ /dev/null @@ -1,22 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import ParaphraseStyleType - -client = AsyncAI21Client() - - -async def main(): - response = await client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.GENERAL, - start_index=0, - end_index=20, - ) - - print(response.suggestions[0].text) - print(response.suggestions[1].text) - print(response.suggestions[2].text) - - -asyncio.run(main()) diff --git a/examples/studio/async_segmentation.py b/examples/studio/async_segmentation.py deleted file mode 100644 index 37961d06..00000000 --- a/examples/studio/async_segmentation.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import DocumentType - -client = AsyncAI21Client() - - -async def main(): - response = await client.segmentation.create( - source="Holland is a geographical region[2] and former province on the western " - "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " - "a unified political region within the Holy Roman Empire as a county ruled by the counts of" - " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, - ) - - print(response.segments[0].segment_text) - print(response.segments[0].segment_type) - - -asyncio.run(main()) diff --git a/examples/studio/async_summarize.py b/examples/studio/async_summarize.py deleted file mode 100644 index 81d0f64f..00000000 --- a/examples/studio/async_summarize.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import DocumentType, SummaryMethod - -client = AsyncAI21Client() - - -async def main(): - response = await client.summarize.create( - source="Holland is a geographical region[2] and former province on the western " - "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " - "a unified political region within the Holy Roman Empire as a county ruled by the counts of" - " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, - summary_length=SummaryMethod.SEGMENTS, - focus="Holland", - ) - print(response.summary) - - -asyncio.run(main()) diff --git a/examples/studio/async_summarize_by_segment.py b/examples/studio/async_summarize_by_segment.py deleted file mode 100644 index 5e59c694..00000000 --- a/examples/studio/async_summarize_by_segment.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio - -from ai21 import AsyncAI21Client -from ai21.models import DocumentType - -client = AsyncAI21Client() - - -async def main(): - response = await client.summarize_by_segment.create( - source="Holland is a geographical region[2] and former province on the western coast of " - "the Netherlands.[2] From the 10th to the 16th century, " - "Holland proper was a unified political region within the Holy Roman Empire as a " - "county ruled by the counts of Holland. By the 17th century, " - "the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, - focus="Holland", - ) - print(response) - - -asyncio.run(main()) diff --git a/examples/studio/completion.py b/examples/studio/completion.py deleted file mode 100644 index 1f21483a..00000000 --- a/examples/studio/completion.py +++ /dev/null @@ -1,76 +0,0 @@ -from ai21 import AI21Client -from ai21.models import Penalty - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AI21Client() -response = client.completion.create( - prompt=prompt, - max_tokens=2, - model="j2-light", - temperature=0, - top_p=1, - top_k_return=0, - stop_sequences=["##"], - num_results=1, - custom_model=None, - epoch=1, - logit_bias={"▁I'm▁sorry": -100.0}, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), -) - -print(response) -print(response.completions[0].data.text) -print(response.prompt.tokens[0]["textRange"]["start"]) diff --git a/examples/studio/custom_model.py b/examples/studio/custom_model.py deleted file mode 100644 index d1f7b7e1..00000000 --- a/examples/studio/custom_model.py +++ /dev/null @@ -1,19 +0,0 @@ -import uuid - -from ai21 import AI21Client - - -client = AI21Client() -my_datasets = client.dataset.list() -if len(my_datasets) > 0: - client.custom_model.create( - dataset_id=my_datasets[0].id, - model_name=f"test-{(str(uuid.uuid4()))[:20]}-asaf", - model_type="j2-mid", - ) - my_models = client.custom_model.list() - - print(my_models) - print(my_models[0]) - - print(client.custom_model.get(my_models[0].id)) diff --git a/examples/studio/custom_model_completion.py b/examples/studio/custom_model_completion.py deleted file mode 100644 index f06f21cf..00000000 --- a/examples/studio/custom_model_completion.py +++ /dev/null @@ -1,42 +0,0 @@ -from ai21 import AI21Client - -prompt = ( - "The following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- " - "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" - " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " - "Hi there, happy to help!\nUser: Is there no way to return a product?" - " I got your blue T-Shirt size small but it doesn't fit.\n" - "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" - "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" - "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" - " associate called Max. Max is very kind and keen to help. The following are important points about" - " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" - 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' - "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" - " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" - "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" - " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." - " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" - " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" - " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" - "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" - "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" - "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" - " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" - " are important points about the business policies:\n- Delivery takes up to 5 days\n" - "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" - "User: Hi, I have a question for you" -) - -client = AI21Client() -response = client.completion.create( - model="j2-grande", - custom_model="test-6acbf857-d8fb-4bd9-a-asaf", - prompt=prompt, - max_tokens=2, -) -print(response) -print(response.prompt.tokens[0]["textRange"]["start"]) diff --git a/examples/studio/dataset.py b/examples/studio/dataset.py deleted file mode 100644 index 87e587cc..00000000 --- a/examples/studio/dataset.py +++ /dev/null @@ -1,11 +0,0 @@ -from ai21 import AI21Client - -file_path = "" - -client = AI21Client() -client.dataset.create(file_path=file_path, dataset_name="my_new_ds_name") -result = client.dataset.list() -print(result) -first_ds_id = result[0].id -result = client.dataset.get(first_ds_id) -print(result) diff --git a/examples/studio/embed.py b/examples/studio/embed.py deleted file mode 100644 index f0dc5a17..00000000 --- a/examples/studio/embed.py +++ /dev/null @@ -1,9 +0,0 @@ -from ai21 import AI21Client -from ai21.models import EmbedType - -client = AI21Client() -response = client.embed.create( - texts=["Holland is a geographical region[2] and former province on the western coast of the Netherlands."], - type=EmbedType.SEGMENT, -) -print("embed: ", response.results[0].embedding) diff --git a/examples/studio/gec.py b/examples/studio/gec.py deleted file mode 100644 index b914a915..00000000 --- a/examples/studio/gec.py +++ /dev/null @@ -1,11 +0,0 @@ -from ai21 import AI21Client - -client = AI21Client() -response = client.gec.create(text="jazzz is a great stile of music") - -print("---------------------") -print(response.corrections[0].suggestion) -print(response.corrections[0].start_index) -print(response.corrections[0].end_index) -print(response.corrections[0].original_text) -print(response.corrections[0].correction_type) diff --git a/examples/studio/improvements.py b/examples/studio/improvements.py deleted file mode 100644 index f75dea58..00000000 --- a/examples/studio/improvements.py +++ /dev/null @@ -1,16 +0,0 @@ -from ai21 import AI21Client -from ai21.models import ImprovementType - -client = AI21Client() -response = client.improvements.create( - text="Affiliated with the profession of project management," - " I have ameliorated myself with a different set of hard skills as well as soft skills.", - types=[ImprovementType.FLUENCY], -) - -print(response.improvements[0].original_text) -print(response.improvements[0].suggestions) -print(response.improvements[0].suggestions[0]) -print(response.improvements[0].improvement_type) -print(response.improvements[1].start_index) -print(response.improvements[1].end_index) diff --git a/examples/studio/library_answer.py b/examples/studio/library_answer.py deleted file mode 100644 index 1c2ae02f..00000000 --- a/examples/studio/library_answer.py +++ /dev/null @@ -1,5 +0,0 @@ -from ai21 import AI21Client - -client = AI21Client() -response = client.library.answer.create(question="Can you tell me something about Holland?") -print(response) diff --git a/examples/studio/library_search.py b/examples/studio/library_search.py deleted file mode 100644 index 20fcf0a1..00000000 --- a/examples/studio/library_search.py +++ /dev/null @@ -1,5 +0,0 @@ -from ai21 import AI21Client - -client = AI21Client() -response = client.library.search.create(query="cat colors", max_segments=2) -print(response) diff --git a/examples/studio/paraphrase.py b/examples/studio/paraphrase.py deleted file mode 100644 index 55c3f5c2..00000000 --- a/examples/studio/paraphrase.py +++ /dev/null @@ -1,14 +0,0 @@ -from ai21 import AI21Client -from ai21.models import ParaphraseStyleType - -client = AI21Client() -response = client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.GENERAL, - start_index=0, - end_index=20, -) - -print(response.suggestions[0].text) -print(response.suggestions[1].text) -print(response.suggestions[2].text) diff --git a/examples/studio/segmentation.py b/examples/studio/segmentation.py deleted file mode 100644 index bf2207cb..00000000 --- a/examples/studio/segmentation.py +++ /dev/null @@ -1,16 +0,0 @@ -from ai21 import AI21Client -from ai21.models import DocumentType - -client = AI21Client() - -response = client.segmentation.create( - source="Holland is a geographical region[2] and former province on the western " - "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " - "a unified political region within the Holy Roman Empire as a county ruled by the counts of" - " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, -) - -print(response.segments[0].segment_text) -print(response.segments[0].segment_type) diff --git a/examples/studio/summarize.py b/examples/studio/summarize.py deleted file mode 100644 index 54b9d7af..00000000 --- a/examples/studio/summarize.py +++ /dev/null @@ -1,15 +0,0 @@ -from ai21 import AI21Client -from ai21.models import DocumentType, SummaryMethod - -client = AI21Client() -response = client.summarize.create( - source="Holland is a geographical region[2] and former province on the western " - "coast of the Netherlands.[2] From the 10th to the 16th century, Holland proper was " - "a unified political region within the Holy Roman Empire as a county ruled by the counts of" - " Holland. By the 17th century, the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, - summary_length=SummaryMethod.SEGMENTS, - focus="Holland", -) -print(response.summary) diff --git a/examples/studio/summarize_by_segment.py b/examples/studio/summarize_by_segment.py deleted file mode 100644 index e2d2a31d..00000000 --- a/examples/studio/summarize_by_segment.py +++ /dev/null @@ -1,15 +0,0 @@ -from ai21 import AI21Client -from ai21.models import DocumentType - -client = AI21Client() -response = client.summarize_by_segment.create( - source="Holland is a geographical region[2] and former province on the western coast of " - "the Netherlands.[2] From the 10th to the 16th century, " - "Holland proper was a unified political region within the Holy Roman Empire as a " - "county ruled by the counts of Holland. By the 17th century, " - "the province of Holland had risen to become a maritime and economic power," - " dominating the other provinces of the newly independent Dutch Republic.", - source_type=DocumentType.TEXT, - focus="Holland", -) -print(response) diff --git a/tests/integration_tests/clients/studio/test_answer.py b/tests/integration_tests/clients/studio/test_answer.py deleted file mode 100644 index 29f540ee..00000000 --- a/tests/integration_tests/clients/studio/test_answer.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from ai21 import AI21Client, AsyncAI21Client - -_CONTEXT = ( - "Holland is a geographical region[2] and former province on the western coast of" - " the Netherlands. From the " - "10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county " - "ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and " - "economic power, dominating the other provinces of the newly independent Dutch Republic." -) - - -@pytest.mark.parametrize( - ids=[ - "when_answer_is_in_context", - "when_answer_not_in_context", - ], - argnames=["question", "is_answer_in_context", "expected_answer_type"], - argvalues=[ - ("When did Holland become an economic power?", True, str), - ("Is the ocean blue?", False, None), - ], -) -def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): - client = AI21Client() - response = client.answer.create( - context=_CONTEXT, - question=question, - ) - - assert response.answer_in_context == is_answer_in_context - if is_answer_in_context: - assert isinstance(response.answer, str) - else: - assert response.answer is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_answer_is_in_context", - "when_answer_not_in_context", - ], - argnames=["question", "is_answer_in_context", "expected_answer_type"], - argvalues=[ - ("When did Holland become an economic power?", True, str), - ("Is the ocean blue?", False, None), - ], -) -async def test_async_answer(question: str, is_answer_in_context: bool, expected_answer_type: type): - client = AsyncAI21Client() - response = await client.answer.create( - context=_CONTEXT, - question=question, - ) - - assert response.answer_in_context == is_answer_in_context - if is_answer_in_context: - assert isinstance(response.answer, str) - else: - assert response.answer is None diff --git a/tests/integration_tests/clients/studio/test_completion.py b/tests/integration_tests/clients/studio/test_completion.py deleted file mode 100644 index 0138f640..00000000 --- a/tests/integration_tests/clients/studio/test_completion.py +++ /dev/null @@ -1,265 +0,0 @@ -import pytest - -from typing import Dict -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import Penalty - -_PROMPT = """ -User: Haven't received a confirmation email for my order #12345. -Assistant: I'm sorry to hear that. I'll look into it right away. -User: Can you please let me know when I can expect to receive it? -""" - - -def test_completion(): - num_results = 3 - - client = AI21Client() - response = client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model="j2-ultra", - temperature=0.7, - top_p=0.2, - top_k_return=0.2, - stop_sequences=["##"], - num_results=num_results, - custom_model=None, - epoch=1, - logit_bias={"▁a▁box▁of": -100.0}, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - assert response.prompt.text == _PROMPT - assert len(response.completions) == num_results - # Check the results aren't all the same - assert len([completion.data.text for completion in response.completions]) == num_results - for completion in response.completions: - assert isinstance(completion.data.text, str) - - -def test_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): - num_results = 5 - - client = AI21Client() - response = client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model="j2-ultra", - temperature=1, - top_p=0, - top_k_return=0, - num_results=num_results, - epoch=1, - ) - - assert response.prompt.text == _PROMPT - assert len(response.completions) == num_results - # Verify all results are the same - assert len(set([completion.data.text for completion in response.completions])) == 1 - - -@pytest.mark.parametrize( - ids=[ - "finish_reason_length", - "finish_reason_endoftext", - "finish_reason_stop_sequence", - ], - argnames=["max_tokens", "stop_sequences", "reason"], - argvalues=[ - (10, "##", "length"), - (100, "##", "endoftext"), - (50, "\n", "stop"), - ], -) -def test_completion_when_finish_reason_defined__should_halt_on_expected_reason( - max_tokens: int, stop_sequences: str, reason: str -): - client = AI21Client() - response = client.completion.create( - prompt=_PROMPT, - max_tokens=max_tokens, - model="j2-ultra", - temperature=1, - top_p=0, - num_results=1, - stop_sequences=[stop_sequences], - top_k_return=0, - epoch=1, - ) - - assert response.completions[0].finish_reason.reason == reason - - -@pytest.mark.parametrize( - ids=[ - "no_logit_bias", - "logit_bias_negative", - ], - argnames=["expected_result", "logit_bias"], - argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})], -) -def test_completion_logit_bias__should_impact_on_response(expected_result: str, logit_bias: Dict[str, float]): - client = AI21Client() - response = client.completion.create( - prompt="Life is like", - max_tokens=3, - model="j2-ultra", - temperature=0, - logit_bias=logit_bias, - ) - - assert response.completions[0].data.text.strip() == expected_result.strip() - - -@pytest.mark.asyncio -async def test_async_completion(): - num_results = 3 - - client = AsyncAI21Client() - response = await client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model="j2-ultra", - temperature=0.7, - top_p=0.2, - top_k_return=0.2, - stop_sequences=["##"], - num_results=num_results, - custom_model=None, - epoch=1, - logit_bias={"▁a▁box▁of": -100.0}, - count_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - frequency_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - presence_penalty=Penalty( - scale=0, - apply_to_emojis=False, - apply_to_numbers=False, - apply_to_stopwords=False, - apply_to_punctuation=False, - apply_to_whitespaces=False, - ), - ) - - assert response.prompt.text == _PROMPT - assert len(response.completions) == num_results - # Check the results aren't all the same - assert len([completion.data.text for completion in response.completions]) == num_results - for completion in response.completions: - assert isinstance(completion.data.text, str) - - -@pytest.mark.asyncio -async def test_async_completion_when_temperature_1_and_top_p_is_0__should_return_same_response(): - num_results = 5 - - client = AsyncAI21Client() - response = await client.completion.create( - prompt=_PROMPT, - max_tokens=64, - model="j2-ultra", - temperature=1, - top_p=0, - top_k_return=0, - num_results=num_results, - epoch=1, - ) - - assert response.prompt.text == _PROMPT - assert len(response.completions) == num_results - # Verify all results are the same - assert len(set([completion.data.text for completion in response.completions])) == 1 - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "finish_reason_length", - "finish_reason_endoftext", - "finish_reason_stop_sequence", - ], - argnames=["max_tokens", "stop_sequences", "reason"], - argvalues=[ - (10, "##", "length"), - (100, "##", "endoftext"), - (50, "\n", "stop"), - ], -) -async def test_async_completion_when_finish_reason_defined__should_halt_on_expected_reason( - max_tokens: int, stop_sequences: str, reason: str -): - client = AsyncAI21Client() - response = await client.completion.create( - prompt=_PROMPT, - max_tokens=max_tokens, - model="j2-ultra", - temperature=1, - top_p=0, - num_results=1, - stop_sequences=[stop_sequences], - top_k_return=0, - epoch=1, - ) - - assert response.completions[0].finish_reason.reason == reason - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "no_logit_bias", - "logit_bias_negative", - ], - argnames=["expected_result", "logit_bias"], - argvalues=[(" a box of chocolates", None), (" riding a bicycle", {"▁a▁box▁of": -100.0})], -) -async def test_async_completion_logit_bias__should_impact_on_response( - expected_result: str, logit_bias: Dict[str, float] -): - client = AsyncAI21Client() - response = await client.completion.create( - prompt="Life is like", - max_tokens=3, - model="j2-ultra", - temperature=0, - logit_bias=logit_bias, - ) - - assert response.completions[0].data.text.strip() == expected_result.strip() diff --git a/tests/integration_tests/clients/studio/test_embed.py b/tests/integration_tests/clients/studio/test_embed.py deleted file mode 100644 index 8c68179c..00000000 --- a/tests/integration_tests/clients/studio/test_embed.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import List - -import pytest -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import EmbedType - -_TEXT_0 = "Holland is a geographical region and former province on the western coast of the Netherlands." -_TEXT_1 = "Germany is a country in Central Europe. It is the second-most populous country in Europe after Russia" - -_SEGMENT_0 = "The sun sets behind the mountains," -_SEGMENT_1 = "casting a warm glow over" -_SEGMENT_2 = "the city of Amsterdam." - - -@pytest.mark.parametrize( - ids=[ - "when_single_text_and_query__should_return_single_embedding", - "when_multiple_text_and_query__should_return_multiple_embeddings", - "when_single_text_and_segment__should_return_single_embedding", - "when_multiple_text_and_segment__should_return_multiple_embeddings", - ], - argnames=["texts", "type"], - argvalues=[ - ([_TEXT_0], EmbedType.QUERY), - ([_TEXT_0, _TEXT_1], EmbedType.QUERY), - ([_SEGMENT_0], EmbedType.SEGMENT), - ([_SEGMENT_0, _SEGMENT_1, _SEGMENT_2], EmbedType.SEGMENT), - ], -) -def test_embed(texts: List[str], type: EmbedType): - client = AI21Client() - response = client.embed.create( - texts=texts, - type=type, - ) - - assert len(response.results) == len(texts) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_single_text_and_query__should_return_single_embedding", - "when_multiple_text_and_query__should_return_multiple_embeddings", - "when_single_text_and_segment__should_return_single_embedding", - "when_multiple_text_and_segment__should_return_multiple_embeddings", - ], - argnames=["texts", "type"], - argvalues=[ - ([_TEXT_0], EmbedType.QUERY), - ([_TEXT_0, _TEXT_1], EmbedType.QUERY), - ([_SEGMENT_0], EmbedType.SEGMENT), - ([_SEGMENT_0, _SEGMENT_1, _SEGMENT_2], EmbedType.SEGMENT), - ], -) -async def test_async_embed(texts: List[str], type: EmbedType): - client = AsyncAI21Client() - response = await client.embed.create( - texts=texts, - type=type, - ) - - assert len(response.results) == len(texts) diff --git a/tests/integration_tests/clients/studio/test_gec.py b/tests/integration_tests/clients/studio/test_gec.py deleted file mode 100644 index 5e9608f2..00000000 --- a/tests/integration_tests/clients/studio/test_gec.py +++ /dev/null @@ -1,60 +0,0 @@ -import pytest -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import CorrectionType - - -@pytest.mark.parametrize( - ids=[ - "should_fix_spelling", - "should_fix_grammar", - "should_fix_missing_word", - "should_fix_punctuation", - "should_fix_wrong_word", - ], - argnames=["text", "correction_type", "expected_suggestion"], - argvalues=[ - ("jazzz is music", CorrectionType.SPELLING, "Jazz"), - ("You am nice", CorrectionType.GRAMMAR, "are"), - ( - "He stared out the window, lost in thought, as the raindrops against the glass.", - CorrectionType.MISSING_WORD, - "raindrops fell against", - ), - ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), - ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), - ], -) -def test_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): - client = AI21Client() - response = client.gec.create(text=text) - assert response.corrections[0].suggestion == expected_suggestion - assert response.corrections[0].correction_type == correction_type - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "should_fix_spelling", - "should_fix_grammar", - "should_fix_missing_word", - "should_fix_punctuation", - "should_fix_wrong_word", - ], - argnames=["text", "correction_type", "expected_suggestion"], - argvalues=[ - ("jazzz is music", CorrectionType.SPELLING, "Jazz"), - ("You am nice", CorrectionType.GRAMMAR, "are"), - ( - "He stared out the window, lost in thought, as the raindrops against the glass.", - CorrectionType.MISSING_WORD, - "raindrops fell against", - ), - ("He is a well known author.", CorrectionType.PUNCTUATION, "well-known"), - ("He is a dog-known author.", CorrectionType.WRONG_WORD, "well-known"), - ], -) -async def test_async_gec(text: str, correction_type: CorrectionType, expected_suggestion: str): - client = AsyncAI21Client() - response = await client.gec.create(text=text) - assert response.corrections[0].suggestion == expected_suggestion - assert response.corrections[0].correction_type == correction_type diff --git a/tests/integration_tests/clients/studio/test_improvements.py b/tests/integration_tests/clients/studio/test_improvements.py deleted file mode 100644 index e30bb159..00000000 --- a/tests/integration_tests/clients/studio/test_improvements.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest - -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import ImprovementType - - -def test_improvements(): - client = AI21Client() - response = client.improvements.create( - text="Affiliated with the profession of project management," - " I have ameliorated myself with a different set of hard skills as well as soft skills.", - types=[ImprovementType.FLUENCY], - ) - - assert len(response.improvements) > 0 - - -@pytest.mark.asyncio -async def test_async_improvements(): - client = AsyncAI21Client() - response = await client.improvements.create( - text="Affiliated with the profession of project management," - " I have ameliorated myself with a different set of hard skills as well as soft skills.", - types=[ImprovementType.FLUENCY], - ) - - assert len(response.improvements) > 0 diff --git a/tests/integration_tests/clients/studio/test_library_answer.py b/tests/integration_tests/clients/studio/test_library_answer.py deleted file mode 100644 index 8f125a4f..00000000 --- a/tests/integration_tests/clients/studio/test_library_answer.py +++ /dev/null @@ -1,38 +0,0 @@ -import pytest -from ai21 import AI21Client, AsyncAI21Client - - -@pytest.mark.skipif -def test_library_answer__when_answer_not_in_context__should_return_false(file_in_library: str): - client = AI21Client() - response = client.library.answer.create(question="Who is Tony Stark?") - assert response.answer is None - assert not response.answer_in_context - - -@pytest.mark.skipif -def test_library_answer__when_answer_in_context__should_return_true(file_in_library: str): - client = AI21Client() - response = client.library.answer.create(question="Who was Albert Einstein?") - assert response.answer is not None - assert response.answer_in_context - assert response.sources[0].file_id == file_in_library - - -@pytest.mark.skipif -@pytest.mark.asyncio -async def test_async_library_answer__when_answer_not_in_context__should_return_false(file_in_library: str): - client = AsyncAI21Client() - response = await client.library.answer.create(question="Who is Tony Stark?") - assert response.answer is None - assert not response.answer_in_context - - -@pytest.mark.skipif -@pytest.mark.asyncio -async def test_async_library_answer__when_answer_in_context__should_return_true(file_in_library: str): - client = AsyncAI21Client() - response = await client.library.answer.create(question="Who was Albert Einstein?") - assert response.answer is not None - assert response.answer_in_context - assert response.sources[0].file_id == file_in_library diff --git a/tests/integration_tests/clients/studio/test_library_search.py b/tests/integration_tests/clients/studio/test_library_search.py deleted file mode 100644 index 339a945f..00000000 --- a/tests/integration_tests/clients/studio/test_library_search.py +++ /dev/null @@ -1,25 +0,0 @@ -from ai21 import AI21Client, AsyncAI21Client -import pytest - - -@pytest.mark.skipif -def test_library_search__when_search__should_return_relevant_results(file_in_library: str): - client = AI21Client() - response = client.library.search.create( - query="What did Albert Einstein get a Nobel Prize for?", labels=["einstein"] - ) - assert len(response.results) > 0 - for result in response.results: - assert result.file_id == file_in_library - - -@pytest.mark.skipif -@pytest.mark.asyncio -async def test_async_library_search__when_search__should_return_relevant_results(file_in_library: str): - client = AsyncAI21Client() - response = await client.library.search.create( - query="What did Albert Einstein get a Nobel Prize for?", labels=["einstein"] - ) - assert len(response.results) > 0 - for result in response.results: - assert result.file_id == file_in_library diff --git a/tests/integration_tests/clients/studio/test_paraphrase.py b/tests/integration_tests/clients/studio/test_paraphrase.py deleted file mode 100644 index def6c937..00000000 --- a/tests/integration_tests/clients/studio/test_paraphrase.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest - -from ai21 import AI21Client, AsyncAI21Client -from ai21.models import ParaphraseStyleType - - -def test_paraphrase(): - client = AI21Client() - response = client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.FORMAL, - start_index=0, - end_index=20, - ) - for suggestion in response.suggestions: - print(suggestion.text) - assert len(response.suggestions) > 0 - - -def test_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): - client = AI21Client() - response = client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.GENERAL, - start_index=0, - end_index=5, - ) - assert len(response.suggestions) == 0 - - -@pytest.mark.parametrize( - ids=["when_general", "when_casual", "when_long", "when_short", "when_formal"], - argnames=["style"], - argvalues=[ - (ParaphraseStyleType.GENERAL,), - (ParaphraseStyleType.CASUAL,), - (ParaphraseStyleType.LONG,), - (ParaphraseStyleType.SHORT,), - (ParaphraseStyleType.FORMAL,), - ], -) -def test_paraphrase_styles(style: ParaphraseStyleType): - client = AI21Client() - response = client.paraphrase.create( - text="Today is a beautiful day.", - style=style, - start_index=0, - end_index=25, - ) - - assert len(response.suggestions) > 0 - - -@pytest.mark.asyncio -async def test_async_paraphrase(): - client = AsyncAI21Client() - response = await client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.FORMAL, - start_index=0, - end_index=20, - ) - for suggestion in response.suggestions: - print(suggestion.text) - assert len(response.suggestions) > 0 - - -@pytest.mark.asyncio -async def test_async_paraphrase__when_start_and_end_index_is_small__should_not_return_suggestions(): - client = AsyncAI21Client() - response = await client.paraphrase.create( - text="The cat (Felis catus) is a domestic species of small carnivorous mammal", - style=ParaphraseStyleType.GENERAL, - start_index=0, - end_index=5, - ) - assert len(response.suggestions) == 0 diff --git a/tests/integration_tests/clients/studio/test_segmentation.py b/tests/integration_tests/clients/studio/test_segmentation.py deleted file mode 100644 index c521be88..00000000 --- a/tests/integration_tests/clients/studio/test_segmentation.py +++ /dev/null @@ -1,100 +0,0 @@ -import pytest -from ai21 import AI21Client, AsyncAI21Client -from ai21.errors import UnprocessableEntity -from ai21.models import DocumentType - -_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the -Netherlands. From the 10th to the 16th century, Holland proper was a unified political - region within the Holy Roman Empire as a county ruled by the counts of Holland. - By the 17th century, the province of Holland had risen to become a maritime and economic power, - dominating the other provinces of the newly independent Dutch Republic.""" - -_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" - - -@pytest.mark.parametrize( - ids=[ - "when_source_is_text__should_return_a_segments", - "when_source_is_url__should_return_a_segments", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.TEXT), - (_SOURCE_URL, DocumentType.URL), - ], -) -def test_segmentation(source: str, source_type: DocumentType): - client = AI21Client() - - response = client.segmentation.create( - source=source, - source_type=source_type, - ) - - assert isinstance(response.segments[0].segment_text, str) - assert response.segments[0].segment_type is not None - - -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - # "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - # (_SOURCE_URL, DocumentType.TEXT), - ], -) -def test_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): - client = AI21Client() - with pytest.raises(UnprocessableEntity): - client.segmentation.create( - source=source, - source_type=source_type, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_source_is_text__should_return_a_segments", - "when_source_is_url__should_return_a_segments", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.TEXT), - (_SOURCE_URL, DocumentType.URL), - ], -) -async def test_async_segmentation(source: str, source_type: DocumentType): - client = AsyncAI21Client() - - response = await client.segmentation.create( - source=source, - source_type=source_type, - ) - - assert isinstance(response.segments[0].segment_text, str) - assert response.segments[0].segment_type is not None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - # "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - # (_SOURCE_URL, DocumentType.TEXT), - ], -) -async def test_async_segmentation__source_and_source_type_misalignment(source: str, source_type: DocumentType): - client = AsyncAI21Client() - with pytest.raises(UnprocessableEntity): - await client.segmentation.create( - source=source, - source_type=source_type, - ) diff --git a/tests/integration_tests/clients/studio/test_summarize.py b/tests/integration_tests/clients/studio/test_summarize.py deleted file mode 100644 index fb4788d7..00000000 --- a/tests/integration_tests/clients/studio/test_summarize.py +++ /dev/null @@ -1,113 +0,0 @@ -import pytest - -from ai21 import AI21Client, AsyncAI21Client -from ai21.errors import UnprocessableEntity -from ai21.models import DocumentType, SummaryMethod - -_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the -Netherlands. From the 10th to the 16th century, Holland proper was a unified political - region within the Holy Roman Empire as a county ruled by the counts of Holland. - By the 17th century, the province of Holland had risen to become a maritime and economic power, - dominating the other provinces of the newly independent Dutch Republic.""" - -_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" - - -@pytest.mark.parametrize( - ids=[ - "when_source_is_text__should_return_a_suggestion", - "when_source_is_url__should_return_a_suggestion", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.TEXT), - (_SOURCE_URL, DocumentType.URL), - ], -) -def test_summarize(source: str, source_type: DocumentType): - focus = "Holland" - - client = AI21Client() - response = client.summarize.create( - source=source, - source_type=source_type, - summary_method=SummaryMethod.SEGMENTS, - focus=focus, - ) - assert response.summary is not None - assert focus in response.summary - - -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - (_SOURCE_URL, DocumentType.TEXT), - ], -) -def test_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): - focus = "Holland" - - client = AI21Client() - with pytest.raises(UnprocessableEntity): - client.summarize.create( - source=source, - source_type=source_type, - summary_method=SummaryMethod.SEGMENTS, - focus=focus, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_source_is_text__should_return_a_suggestion", - "when_source_is_url__should_return_a_suggestion", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.TEXT), - (_SOURCE_URL, DocumentType.URL), - ], -) -async def test_async_summarize(source: str, source_type: DocumentType): - focus = "Holland" - - client = AsyncAI21Client() - response = await client.summarize.create( - source=source, - source_type=source_type, - summary_method=SummaryMethod.SEGMENTS, - focus=focus, - ) - assert response.summary is not None - assert focus in response.summary - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - (_SOURCE_URL, DocumentType.TEXT), - ], -) -async def test_async_summarize__source_and_source_type_misalignment(source: str, source_type: DocumentType): - focus = "Holland" - - client = AsyncAI21Client() - with pytest.raises(UnprocessableEntity): - await client.summarize.create( - source=source, - source_type=source_type, - summary_method=SummaryMethod.SEGMENTS, - focus=focus, - ) diff --git a/tests/integration_tests/clients/studio/test_summarize_by_segment.py b/tests/integration_tests/clients/studio/test_summarize_by_segment.py deleted file mode 100644 index dfd477c1..00000000 --- a/tests/integration_tests/clients/studio/test_summarize_by_segment.py +++ /dev/null @@ -1,122 +0,0 @@ -import pytest - -from ai21 import AI21Client, AsyncAI21Client -from ai21.errors import UnprocessableEntity -from ai21.models import DocumentType - -_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the Netherlands. - From the 10th to the 16th century, Holland proper was a unified political - region within the Holy Roman Empire as a county ruled by the counts of Holland. - By the 17th century, the province of Holland had risen to become a maritime and economic power, - dominating the other provinces of the newly independent Dutch Republic.""" - -_SOURCE_URL = "https://en.wikipedia.org/wiki/Holland" - - -def test_summarize_by_segment__when_text__should_return_response(): - client = AI21Client() - response = client.summarize_by_segment.create( - source=_SOURCE_TEXT, - source_type=DocumentType.TEXT, - focus="Holland", - ) - assert isinstance(response.segments[0].segment_text, str) - assert response.segments[0].segment_html is None - assert isinstance(response.segments[0].summary, str) - assert len(response.segments[0].highlights) > 0 - assert response.segments[0].segment_type == "normal_text" - assert response.segments[0].has_summary - - -def test_summarize_by_segment__when_url__should_return_response(): - client = AI21Client() - response = client.summarize_by_segment.create( - source=_SOURCE_URL, - source_type=DocumentType.URL, - focus="Holland", - ) - assert isinstance(response.segments[0].segment_text, str) - assert isinstance(response.segments[0].segment_html, str) - assert isinstance(response.segments[0].summary, str) - assert response.segments[0].segment_type == "normal_text" - assert len(response.segments[0].highlights) > 0 - assert response.segments[0].has_summary - - -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - (_SOURCE_URL, DocumentType.TEXT), - ], -) -def test_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): - focus = "Holland" - - client = AI21Client() - with pytest.raises(UnprocessableEntity): - client.summarize_by_segment.create( - source=source, - source_type=source_type, - focus=focus, - ) - - -@pytest.mark.asyncio -async def test_async_summarize_by_segment__when_text__should_return_response(): - client = AsyncAI21Client() - response = await client.summarize_by_segment.create( - source=_SOURCE_TEXT, - source_type=DocumentType.TEXT, - focus="Holland", - ) - assert isinstance(response.segments[0].segment_text, str) - assert response.segments[0].segment_html is None - assert isinstance(response.segments[0].summary, str) - assert len(response.segments[0].highlights) > 0 - assert response.segments[0].segment_type == "normal_text" - assert response.segments[0].has_summary - - -@pytest.mark.asyncio -async def test_async_summarize_by_segment__when_url__should_return_response(): - client = AsyncAI21Client() - response = await client.summarize_by_segment.create( - source=_SOURCE_URL, - source_type=DocumentType.URL, - focus="Holland", - ) - assert isinstance(response.segments[0].segment_text, str) - assert isinstance(response.segments[0].segment_html, str) - assert isinstance(response.segments[0].summary, str) - assert response.segments[0].segment_type == "normal_text" - assert len(response.segments[0].highlights) > 0 - assert response.segments[0].has_summary - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - ids=[ - "when_source_is_text_and_source_type_is_url__should_raise_error", - "when_source_is_url_and_source_type_is_text__should_raise_error", - ], - argnames=["source", "source_type"], - argvalues=[ - (_SOURCE_TEXT, DocumentType.URL), - (_SOURCE_URL, DocumentType.TEXT), - ], -) -async def test_async_summarize_by_segment__source_and_source_type_misalignment(source: str, source_type: DocumentType): - focus = "Holland" - - client = AsyncAI21Client() - with pytest.raises(UnprocessableEntity): - await client.summarize_by_segment.create( - source=source, - source_type=source_type, - focus=focus, - ) diff --git a/tests/integration_tests/clients/test_sagemaker.py b/tests/integration_tests/clients/test_sagemaker.py deleted file mode 100644 index c1d8f3d8..00000000 --- a/tests/integration_tests/clients/test_sagemaker.py +++ /dev/null @@ -1,43 +0,0 @@ -import subprocess - -import pytest -from pathlib import Path - -SAGEMAKER_DIR = "sagemaker" - -SAGEMAKER_PATH = Path(__file__).parent.parent.parent.parent / "examples" / SAGEMAKER_DIR - - -@pytest.mark.skip(reason="SageMaker integration tests need endpoints to be running") -@pytest.mark.parametrize( - argnames=["test_file_name"], - argvalues=[ - ("answer.py",), - ("async_answer.py",), - ("completion.py",), - ("async_completion.py",), - ("gec.py",), - ("async_gec.py",), - ("paraphrase.py",), - ("async_paraphrase.py",), - ("summarization.py",), - ("async_summarization.py",), - ], - ids=[ - "when_answer__should_return_ok", - "when_async_answer__should_return_ok", - "when_completion__should_return_ok", - "when_async_completion__should_return_ok", - "when_gec__should_return_ok", - "when_async_gec__should_return_ok", - "when_paraphrase__should_return_ok", - "when_async_paraphrase__should_return_ok", - "when_summarization__should_return_ok", - "when_async_summarization__should_return_ok", - ], -) -def test_sagemaker(test_file_name: str): - file_path = SAGEMAKER_PATH / test_file_name - print(f"About to run: {file_path}") - exit_code = subprocess.call(["python", file_path]) - assert exit_code == 0, f"failed to run {test_file_name}" diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 945d8045..5cd0d4dd 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -17,15 +17,6 @@ @pytest.mark.parametrize( argnames=["test_file_name"], argvalues=[ - ("answer.py",), - ("completion.py",), - ("embed.py",), - ("gec.py",), - ("improvements.py",), - ("paraphrase.py",), - ("segmentation.py",), - ("summarize.py",), - ("summarize_by_segment.py",), ("tokenization.py",), ("chat/chat_completions.py",), ("chat/chat_completions_jamba_instruct.py",), @@ -34,22 +25,8 @@ ("chat/chat_function_calling.py",), ("chat/chat_function_calling_multiple_tools.py",), ("chat/chat_response_format.py",), - # ("custom_model.py", ), - # ('custom_model_completion.py', ), - # ("dataset.py", ), - # ("library.py", ), - # ("library_answer.py", ), ], ids=[ - "when_answer__should_return_ok", - "when_completion__should_return_ok", - "when_embed__should_return_ok", - "when_gec__should_return_ok", - "when_improvements__should_return_ok", - "when_paraphrase__should_return_ok", - "when_segmentation__should_return_ok", - "when_summarize__should_return_ok", - "when_summarize_by_segment__should_return_ok", "when_tokenization__should_return_ok", "when_chat_completions__should_return_ok", "when_chat_completions_jamba_instruct__should_return_ok", @@ -58,11 +35,6 @@ "when_chat_completions_with_function_calling__should_return_ok", "when_chat_completions_with_function_calling_multiple_tools_should_return_ok", "when_chat_completions_with_response_format__should_return_ok", - # "when_custom_model__should_return_ok", - # "when_custom_model_completion__should_return_ok", - # "when_dataset__should_return_ok", - # "when_library__should_return_ok", - # "when_library_answer__should_return_ok", ], ) def test_studio(test_file_name: str): @@ -78,46 +50,16 @@ def test_studio(test_file_name: str): @pytest.mark.parametrize( argnames=["test_file_name"], argvalues=[ - ("async_answer.py",), ("async_chat.py",), - ("async_completion.py",), - ("async_embed.py",), - ("async_gec.py",), - ("async_improvements.py",), - ("async_paraphrase.py",), - ("async_segmentation.py",), - ("async_summarize.py",), - ("async_summarize_by_segment.py",), - # ("async_tokenization.py",), ("chat/async_chat_completions.py",), ("chat/async_stream_chat_completions.py",), - # ("async_custom_model.py", ), - # ("async_custom_model_completion.py", ), - # ("async_dataset.py", ), - # ("async_library.py", ), - # ("async_library_answer.py", ), ("conversational_rag/conversational_rag.py",), ("conversational_rag/async_conversational_rag.py",), ], ids=[ - "when_answer__should_return_ok", "when_chat__should_return_ok", - "when_completion__should_return_ok", - "when_embed__should_return_ok", - "when_gec__should_return_ok", - "when_improvements__should_return_ok", - "when_paraphrase__should_return_ok", - "when_segmentation__should_return_ok", - "when_summarize__should_return_ok", - "when_summarize_by_segment__should_return_ok", - # "when_tokenization__should_return_ok", "when_chat_completions__should_return_ok", "when_stream_chat_completions__should_return_ok", - # "when_custom_model__should_return_ok", - # "when_custom_model_completion__should_return_ok", - # "when_dataset__should_return_ok", - # "when_library__should_return_ok", - # "when_library_answer__should_return_ok", "when_conversational_rag__should_return_ok", "when_async_conversational_rag__should_return_ok", ], diff --git a/tests/integration_tests/services/__init__.py b/tests/integration_tests/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration_tests/services/test_sagemaker.py b/tests/integration_tests/services/test_sagemaker.py deleted file mode 100644 index 29e61be2..00000000 --- a/tests/integration_tests/services/test_sagemaker.py +++ /dev/null @@ -1,36 +0,0 @@ -import os - -import pytest - -from ai21 import SageMaker - - -def _add_or_remove_api_key(use_api_key: bool): - if use_api_key: - os.environ["AI21_API_KEY"] = "test" - else: - del os.environ["AI21_API_KEY"] - - -@pytest.mark.parametrize( - argnames="use_api_key", - argvalues=[True, False], - ids=["with_api_key", "without_api_key"], -) -def test_sagemaker__get_model_package_arn(use_api_key: bool): - _add_or_remove_api_key(use_api_key) - model_packages_arn = SageMaker.get_model_package_arn(model_name="j2-mid", region="us-east-1") - assert isinstance(model_packages_arn, str) - assert len(model_packages_arn) > 0 - - -@pytest.mark.parametrize( - argnames="use_api_key", - argvalues=[True, False], - ids=["with_api_key", "without_api_key"], -) -def test_sagemaker__list_model_package_versions(use_api_key: bool): - _add_or_remove_api_key(use_api_key) - model_packages_arn = SageMaker.list_model_package_versions(model_name="j2-mid", region="us-east-1") - assert isinstance(model_packages_arn, list) - assert len(model_packages_arn) > 0 diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index a26d2514..1e8413b5 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -4,39 +4,15 @@ from ai21.clients.studio.resources.chat import AsyncChatCompletions from ai21.clients.studio.resources.chat import ChatCompletions -from ai21.clients.studio.resources.studio_answer import StudioAnswer, AsyncStudioAnswer from ai21.clients.studio.resources.studio_chat import StudioChat, AsyncStudioChat from ai21.clients.studio.resources.studio_completion import StudioCompletion, AsyncStudioCompletion -from ai21.clients.studio.resources.studio_embed import StudioEmbed, AsyncStudioEmbed -from ai21.clients.studio.resources.studio_gec import StudioGEC, AsyncStudioGEC -from ai21.clients.studio.resources.studio_improvements import StudioImprovements, AsyncStudioImprovements -from ai21.clients.studio.resources.studio_paraphrase import StudioParaphrase, AsyncStudioParaphrase -from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation, AsyncStudioSegmentation -from ai21.clients.studio.resources.studio_summarize import StudioSummarize, AsyncStudioSummarize -from ai21.clients.studio.resources.studio_summarize_by_segment import ( - StudioSummarizeBySegment, - AsyncStudioSummarizeBySegment, -) from ai21.http_client.async_http_client import AsyncAI21HTTPClient from ai21.http_client.http_client import AI21HTTPClient from ai21.models import ( - AnswerResponse, ChatMessage, RoleType, ChatResponse, CompletionsResponse, - EmbedType, - EmbedResponse, - GECResponse, - ImprovementType, - ImprovementsResponse, - ParaphraseStyleType, - ParaphraseResponse, - DocumentType, - SegmentationResponse, - SummaryMethod, - SummarizeResponse, - SummarizeBySegmentResponse, ) from ai21.models._pydantic_compatibility import _to_dict, _from_dict from ai21.models.chat import ( @@ -72,25 +48,6 @@ def mock_async_successful_httpx_response(mocker: MockerFixture) -> httpx.Respons return async_mock_httpx_response -def get_studio_answer(is_async: bool = False): - _DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" - _DUMMY_QUESTION = "What is the answer?" - json_response = {"id": "some-id", "answer_in_context": True, "answer": "42"} - resource = AsyncStudioAnswer if is_async else StudioAnswer - - return ( - resource, - {"context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION}, - "answer", - { - "context": _DUMMY_CONTEXT, - "question": _DUMMY_QUESTION, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=AnswerResponse, obj_dict=json_response), - ) - - def get_studio_chat(is_async: bool = False): _DUMMY_MODEL = "dummy-chat-model" _DUMMY_MESSAGES = [ @@ -207,207 +164,3 @@ def get_studio_completion(is_async: bool = True, **kwargs): httpx.Response(status_code=200, json=json_response), _from_dict(obj=CompletionsResponse, obj_dict=json_response), ) - - -def get_studio_embed(is_async: bool = False): - json_response = { - "id": "some-id", - "results": [ - {"embedding": [1.0, 2.0, 3.0]}, - {"embedding": [4.0, 5.0, 6.0]}, - ], - } - - resource = AsyncStudioEmbed if is_async else StudioEmbed - - return ( - resource, - {"texts": ["text0", "text1"], "type": EmbedType.QUERY}, - "embed", - { - "texts": ["text0", "text1"], - "type": EmbedType.QUERY.value, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=EmbedResponse, obj_dict=json_response), - ) - - -def get_studio_gec(is_async: bool = False): - json_response = { - "id": "some-id", - "corrections": [ - { - "suggestion": "text to fix", - "startIndex": 9, - "endIndex": 10, - "originalText": "text to fi", - "correctionType": "Spelling", - } - ], - } - text = "text to fi" - - resource = AsyncStudioGEC if is_async else StudioGEC - - return ( - resource, - {"text": text}, - "gec", - { - "text": text, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=GECResponse, obj_dict=json_response), - ) - - -def get_studio_improvements(is_async: bool = False): - json_response = { - "id": "some-id", - "improvements": [ - { - "suggestions": ["This text is improved"], - "startIndex": 0, - "endIndex": 15, - "originalText": "text to improve", - "improvementType": "FLUENCY", - } - ], - } - text = "text to improve" - types = [ImprovementType.FLUENCY] - - resource = AsyncStudioImprovements if is_async else StudioImprovements - - return ( - resource, - {"text": text, "types": types}, - "improvements", - { - "text": text, - "types": types, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=ImprovementsResponse, obj_dict=json_response), - ) - - -def get_studio_paraphrase(is_async: bool = False): - text = "text to paraphrase" - style = ParaphraseStyleType.CASUAL - start_index = 0 - end_index = 10 - json_response = { - "id": "some-id", - "suggestions": [ - { - "text": "This text is paraphrased", - } - ], - } - - resource = AsyncStudioParaphrase if is_async else StudioParaphrase - - return ( - resource, - {"text": text, "style": style, "start_index": start_index, "end_index": end_index}, - "paraphrase", - { - "text": text, - "style": style, - "startIndex": start_index, - "endIndex": end_index, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=ParaphraseResponse, obj_dict=json_response), - ) - - -def get_studio_segmentation(is_async: bool = False): - source = "segmentation text" - source_type = DocumentType.TEXT - json_response = { - "id": "some-id", - "segments": [ - { - "segmentText": "This text is segmented", - "segmentType": "segment_type", - } - ], - } - - resource = AsyncStudioSegmentation if is_async else StudioSegmentation - - return ( - resource, - {"source": source, "source_type": source_type}, - "segmentation", - { - "source": source, - "sourceType": source_type, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=SegmentationResponse, obj_dict=json_response), - ) - - -def get_studio_summarization(is_async: bool = False): - source = "text to summarize" - source_type = "TEXT" - focus = "text" - summary_method = SummaryMethod.FULL_DOCUMENT - json_response = { - "id": "some-id", - "summary": "This text is summarized", - } - - resource = AsyncStudioSummarize if is_async else StudioSummarize - - return ( - resource, - {"source": source, "source_type": source_type, "focus": focus, "summary_method": summary_method}, - "summarize", - { - "source": source, - "sourceType": source_type, - "focus": focus, - "summaryMethod": summary_method, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=SummarizeResponse, obj_dict=json_response), - ) - - -def get_studio_summarize_by_segment(is_async: bool = False): - source = "text to summarize" - source_type = "TEXT" - focus = "text" - json_response = { - "id": "some-id", - "segments": [ - { - "summary": "This text is summarized", - "segmentText": "This text is segmented", - "segmentHtml": "", - "segmentType": "segment_type", - "hasSummary": True, - "highlights": [], - } - ], - } - - resource = AsyncStudioSummarizeBySegment if is_async else StudioSummarizeBySegment - - return ( - resource, - {"source": source, "source_type": source_type, "focus": focus}, - "summarize-by-segment", - { - "source": source, - "sourceType": source_type, - "focus": focus, - }, - httpx.Response(status_code=200, json=json_response), - _from_dict(obj=SummarizeBySegmentResponse, obj_dict=json_response), - ) diff --git a/tests/unittests/clients/studio/resources/test_async_studio_resource.py b/tests/unittests/clients/studio/resources/test_async_studio_resource.py index 96f8cdff..0897479c 100644 --- a/tests/unittests/clients/studio/resources/test_async_studio_resource.py +++ b/tests/unittests/clients/studio/resources/test_async_studio_resource.py @@ -1,31 +1,17 @@ from typing import TypeVar, Callable -import httpx import pytest -from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient -from ai21.models import AnswerResponse -from ai21.models._pydantic_compatibility import _to_dict from ai21.models.ai21_base_model import AI21BaseModel from tests.unittests.clients.studio.resources.conftest import ( - get_studio_answer, get_studio_chat, get_chat_completions, get_studio_completion, - get_studio_embed, - get_studio_gec, - get_studio_improvements, - get_studio_paraphrase, - get_studio_segmentation, - get_studio_summarization, - get_studio_summarize_by_segment, ) _BASE_URL = "https://test.api.ai21.com/studio/v1" -_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" -_DUMMY_QUESTION = "What is the answer?" T = TypeVar("T", bound=AsyncStudioResource) @@ -34,18 +20,10 @@ class TestAsyncStudioResources: @pytest.mark.asyncio @pytest.mark.parametrize( ids=[ - "async_studio_answer", "async_studio_chat", "async_chat_completions", "async_studio_completion", "async_studio_completion_with_extra_args", - "async_studio_embed", - "async_studio_gec", - "async_studio_improvements", - "async_studio_paraphrase", - "async_studio_segmentation", - "async_studio_summarization", - "async_studio_summarize_by_segment", ], argnames=[ "studio_resource", @@ -56,18 +34,10 @@ class TestAsyncStudioResources: "expected_response", ], argvalues=[ - (get_studio_answer(is_async=True)), (get_studio_chat(is_async=True)), (get_chat_completions(is_async=True)), (get_studio_completion(is_async=True)), (get_studio_completion(is_async=True, temperature=0.5, max_tokens=50)), - (get_studio_embed(is_async=True)), - (get_studio_gec(is_async=True)), - (get_studio_improvements(is_async=True)), - (get_studio_paraphrase(is_async=True)), - (get_studio_segmentation(is_async=True)), - (get_studio_summarization(is_async=True)), - (get_studio_summarize_by_segment(is_async=True)), ], ) async def test__create__should_return_response( @@ -97,34 +67,3 @@ async def test__create__should_return_response( stream=False, files=None, ) - - @pytest.mark.asyncio - async def test__create__when_pass_kwargs__should_pass_to_request( - self, - mock_async_ai21_studio_client: AsyncAI21HTTPClient, - mock_async_successful_httpx_response: httpx.Response, - ): - expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") - mock_async_successful_httpx_response.json.return_value = _to_dict(expected_answer) - - mock_async_ai21_studio_client.execute_http_request.return_value = mock_async_successful_httpx_response - studio_answer = AsyncStudioAnswer(mock_async_ai21_studio_client) - - await studio_answer.create( - context=_DUMMY_CONTEXT, - question=_DUMMY_QUESTION, - some_dummy_kwargs="some_dummy_value", - ) - - mock_async_ai21_studio_client.execute_http_request.assert_called_with( - method="POST", - path="/answer", - body={ - "context": _DUMMY_CONTEXT, - "question": _DUMMY_QUESTION, - "some_dummy_kwargs": "some_dummy_value", - }, - params={}, - stream=False, - files=None, - ) diff --git a/tests/unittests/clients/studio/resources/test_completion.py b/tests/unittests/clients/studio/resources/test_completion.py deleted file mode 100644 index 59e401f9..00000000 --- a/tests/unittests/clients/studio/resources/test_completion.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest -from ai21 import AI21Client - - -def test__when_model_and_model_id__raise_error(): - client = AI21Client() - with pytest.raises(ValueError): - client.completion.create( - model="j2-ultra", - model_id="j2-ultra", - prompt="test prompt", - ) diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index 39f6833d..8f3cdd7c 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -1,49 +1,26 @@ from typing import TypeVar, Callable import pytest -import httpx from ai21.http_client.http_client import AI21HTTPClient -from ai21.clients.studio.resources.studio_answer import StudioAnswer from ai21.clients.studio.resources.studio_resource import StudioResource -from ai21.models import AnswerResponse -from ai21.models._pydantic_compatibility import _to_dict from ai21.models.ai21_base_model import AI21BaseModel from tests.unittests.clients.studio.resources.conftest import ( - get_studio_answer, get_studio_chat, get_studio_completion, - get_studio_embed, - get_studio_gec, - get_studio_improvements, - get_studio_paraphrase, - get_studio_segmentation, - get_studio_summarization, - get_studio_summarize_by_segment, get_chat_completions, ) _BASE_URL = "https://test.api.ai21.com/studio/v1" -_DUMMY_CONTEXT = "What is the answer to life, the universe and everything?" -_DUMMY_QUESTION = "What is the answer?" - T = TypeVar("T", bound=StudioResource) class TestStudioResources: @pytest.mark.parametrize( ids=[ - "studio_answer", "studio_chat", "chat_completions", "studio_completion", "studio_completion_with_extra_args", - "studio_embed", - "studio_gec", - "studio_improvements", - "studio_paraphrase", - "studio_segmentation", - "studio_summarization", - "studio_summarize_by_segment", ], argnames=[ "studio_resource", @@ -54,18 +31,10 @@ class TestStudioResources: "expected_response", ], argvalues=[ - (get_studio_answer()), (get_studio_chat()), (get_chat_completions()), (get_studio_completion(is_async=False)), (get_studio_completion(is_async=False, temperature=0.5, max_tokens=50)), - (get_studio_embed()), - (get_studio_gec()), - (get_studio_improvements()), - (get_studio_paraphrase()), - (get_studio_segmentation()), - (get_studio_summarization()), - (get_studio_summarize_by_segment()), ], ) def test__create__should_return_response( @@ -95,33 +64,3 @@ def test__create__should_return_response( stream=False, files=None, ) - - def test__create__when_pass_kwargs__should_pass_to_request( - self, - mock_ai21_studio_client: AI21HTTPClient, - mock_successful_httpx_response: httpx.Response, - ): - expected_answer = AnswerResponse(id="some-id", answer_in_context=True, answer="42") - mock_successful_httpx_response.json.return_value = _to_dict(expected_answer) - - mock_ai21_studio_client.execute_http_request.return_value = mock_successful_httpx_response - studio_answer = StudioAnswer(mock_ai21_studio_client) - - studio_answer.create( - context=_DUMMY_CONTEXT, - question=_DUMMY_QUESTION, - some_dummy_kwargs="some_dummy_value", - ) - - mock_ai21_studio_client.execute_http_request.assert_called_with( - method="POST", - path="/answer", - body={ - "context": _DUMMY_CONTEXT, - "question": _DUMMY_QUESTION, - "some_dummy_kwargs": "some_dummy_value", - }, - params={}, - stream=False, - files=None, - ) diff --git a/tests/unittests/models/response_mocks.py b/tests/unittests/models/response_mocks.py index 4a7b2245..ffb55ad2 100644 --- a/tests/unittests/models/response_mocks.py +++ b/tests/unittests/models/response_mocks.py @@ -1,5 +1,4 @@ from ai21.models import ( - AnswerResponse, ChatResponse, ChatOutput, RoleType, @@ -9,42 +8,12 @@ Completion, CompletionFinishReason, CompletionData, - EmbedResponse, - EmbedResult, - GECResponse, - Correction, - ImprovementsResponse, - Improvement, - ParaphraseResponse, - Suggestion, - SegmentationResponse, - SummarizeResponse, - SummarizeBySegmentResponse, - SegmentSummary, - Highlight, ) from ai21.models.chat import ChatCompletionResponse, ChatCompletionResponseChoice from ai21.models.chat.chat_message import AssistantMessage -from ai21.models.responses.segmentation_response import Segment from ai21.models.usage_info import UsageInfo -def get_answer_response__answer_in_context_not_none(): - expected_dict = {"id": "123", "answerInContext": True, "answer": "Koalas eat the leaves of Eucalyptus trees."} - answer_response = AnswerResponse( - id="123", answer_in_context=True, answer="Koalas eat the leaves of Eucalyptus trees." - ) - - return answer_response, expected_dict, AnswerResponse - - -def get_answer_response__answer_in_context_is_none(): - expected_dict = {"id": "123", "answerInContext": False, "answer": None} - answer_response = AnswerResponse(id="123", answer_in_context=False) - - return answer_response, expected_dict, AnswerResponse - - def get_chat_response(): expected_dict = { "outputs": [ @@ -181,153 +150,3 @@ def get_completions_response(): completion_response = CompletionsResponse(id="123-abc", prompt=prompt, completions=[completion]) return completion_response, expected_dict, CompletionsResponse - - -def get_embed_response(): - expected_dict = { - "id": "123", - "results": [ - { - "embedding": [ - 0.03452427685260773, - -0.0011991093633696437, - ] - } - ], - } - - embed_response = EmbedResponse( - id="123", results=[EmbedResult(embedding=[0.03452427685260773, -0.0011991093633696437])] - ) - - return embed_response, expected_dict, EmbedResponse - - -def get_gec_response(): - expected_dict = { - "id": "123", - "corrections": [ - { - "suggestion": "love rock", - "startIndex": 2, - "endIndex": 9, - "originalText": "luv rok", - "correctionType": "Spelling", - } - ], - } - - gec_response = GECResponse( - id="123", - corrections=[ - Correction( - suggestion="love rock", start_index=2, end_index=9, original_text="luv rok", correction_type="Spelling" - ) - ], - ) - - return gec_response, expected_dict, GECResponse - - -def get_improvements_response(): - expected_dict = { - "id": "123", - "improvements": [ - { - "suggestions": ["technical", "practical", "analytical"], - "startIndex": 104, - "endIndex": 108, - "originalText": "hard", - "improvementType": "vocabulary/specificity", - }, - ], - } - - improvements_response = ImprovementsResponse( - id="123", - improvements=[ - Improvement( - suggestions=["technical", "practical", "analytical"], - start_index=104, - end_index=108, - original_text="hard", - improvement_type="vocabulary/specificity", - ) - ], - ) - - return improvements_response, expected_dict, ImprovementsResponse - - -def get_paraphrase_response(): - expected_dict = { - "id": "123", - "suggestions": [ - {"text": "Thank you so much for the gift I received on Monday."}, - ], - } - - paraphrase_response = ParaphraseResponse( - id="123", suggestions=[Suggestion(text="Thank you so much for the gift I received on Monday.")] - ) - - return paraphrase_response, expected_dict, ParaphraseResponse - - -def get_segmentation_response(): - expected_dict = { - "id": "123", - "segments": [ - {"segmentText": "Further reading", "segmentType": "h2"}, - ], - } - - segmentation_response = SegmentationResponse( - id="123", segments=[Segment(segment_text="Further reading", segment_type="h2")] - ) - - return segmentation_response, expected_dict, SegmentationResponse - - -def get_summarize_response(): - expected_dict = { - "id": "123", - "summary": "The blue whale is a marine mammal that lives off California's coast.", - } - - summarization_response = SummarizeResponse( - id="123", - summary="The blue whale is a marine mammal that lives off California's coast.", - ) - - return summarization_response, expected_dict, SummarizeResponse - - -def get_summarize_by_segment_response(): - expected_dict = { - "id": "123", - "segments": [ - { - "summary": "The blue whale is the largest animal known ever to have existed.", - "segmentType": "normal_text", - "hasSummary": True, - "highlights": [{"text": "The blue whale", "startIndex": 0, "endIndex": 14}], - "segmentHtml": None, - "segmentText": None, - }, - ], - } - - summarization_response = SummarizeBySegmentResponse( - id="123", - segments=[ - SegmentSummary( - summary="The blue whale is the largest animal known ever to have existed.", - segment_type="normal_text", - has_summary=True, - highlights=[Highlight(text="The blue whale", start_index=0, end_index=14)], - ), - ], - ) - - return summarization_response, expected_dict, SummarizeBySegmentResponse diff --git a/tests/unittests/models/test_serialization.py b/tests/unittests/models/test_serialization.py index 1b5d3e82..35473cd0 100644 --- a/tests/unittests/models/test_serialization.py +++ b/tests/unittests/models/test_serialization.py @@ -6,18 +6,9 @@ from ai21.models._pydantic_compatibility import _to_dict, _from_dict from ai21.models.ai21_base_model import IS_PYDANTIC_V2, AI21BaseModel from tests.unittests.models.response_mocks import ( - get_answer_response__answer_in_context_not_none, - get_answer_response__answer_in_context_is_none, get_chat_response, get_chat_completions_response, get_completions_response, - get_embed_response, - get_gec_response, - get_improvements_response, - get_paraphrase_response, - get_segmentation_response, - get_summarize_response, - get_summarize_by_segment_response, ) @@ -50,18 +41,9 @@ def test_penalty__from_json__should_return_instance_with_given_values(): @pytest.mark.parametrize( ids=[ - "answer_response__answer_in_context", - "answer_response__answer_not_in_context", "chat_response", "chat_completions_response", "completion_response", - "embed_response", - "gec_response", - "improvements_response", - "paraphrase_response", - "segmentation_response", - "summarization_response", - "summarize_by_segment_response", ], argnames=[ "response_obj", @@ -69,18 +51,9 @@ def test_penalty__from_json__should_return_instance_with_given_values(): "response_cls", ], argvalues=[ - (get_answer_response__answer_in_context_not_none()), - (get_answer_response__answer_in_context_is_none()), (get_chat_response()), (get_chat_completions_response()), (get_completions_response()), - (get_embed_response()), - (get_gec_response()), - (get_improvements_response()), - (get_paraphrase_response()), - (get_segmentation_response()), - (get_summarize_response()), - (get_summarize_by_segment_response()), ], ) def test_to_dict__should_serialize_to_dict__( @@ -92,18 +65,9 @@ def test_to_dict__should_serialize_to_dict__( @pytest.mark.parametrize( ids=[ - "answer_response__answer_in_context_not_none", - "answer_response__answer_in_context_is_none", "chat_response", "chat_completions_response", "completion_response", - "embed_response", - "gec_response", - "improvements_response", - "paraphrase_response", - "segmentation_response", - "summarization_response", - "summarize_by_segment_response", ], argnames=[ "response_obj", @@ -111,18 +75,9 @@ def test_to_dict__should_serialize_to_dict__( "response_cls", ], argvalues=[ - (get_answer_response__answer_in_context_not_none()), - (get_answer_response__answer_in_context_is_none()), (get_chat_response()), (get_chat_completions_response()), (get_completions_response()), - (get_embed_response()), - (get_gec_response()), - (get_improvements_response()), - (get_paraphrase_response()), - (get_segmentation_response()), - (get_summarize_response()), - (get_summarize_by_segment_response()), ], ) def test_from_dict__should_serialize_from_dict__( diff --git a/tests/unittests/services/__init__.py b/tests/unittests/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unittests/services/sagemaker_stub.py b/tests/unittests/services/sagemaker_stub.py deleted file mode 100644 index 18cc01fa..00000000 --- a/tests/unittests/services/sagemaker_stub.py +++ /dev/null @@ -1,12 +0,0 @@ -from unittest.mock import Mock - -from ai21 import SageMaker -from ai21.http_client.http_client import AI21HTTPClient - - -class SageMakerStub(SageMaker): - ai21_http_client = Mock(spec=AI21HTTPClient) - - @classmethod - def _create_ai21_http_client(cls, path: str) -> AI21HTTPClient: - return cls.ai21_http_client diff --git a/tests/unittests/services/test_sagemaker.py b/tests/unittests/services/test_sagemaker.py deleted file mode 100644 index 1e8fdc04..00000000 --- a/tests/unittests/services/test_sagemaker.py +++ /dev/null @@ -1,46 +0,0 @@ -import httpx -import pytest - -from ai21 import ModelPackageDoesntExistError -from tests.unittests.services.sagemaker_stub import SageMakerStub - -_DUMMY_ARN = "some-model-package-id1" -_DUMMY_VERSIONS = ["1.0.0", "1.0.1"] - - -class TestSageMakerService: - def test__get_model_package_arn__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): - mock_httpx_response.json.return_value = { - "arn": _DUMMY_ARN, - "versions": _DUMMY_VERSIONS, - } - SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response - - actual_model_package_arn = SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - - assert actual_model_package_arn == _DUMMY_ARN - - def test__get_model_package_arn__when_no_arn__should_raise_error(self, mock_httpx_response: httpx.Response): - mock_httpx_response.json.return_value = {"arn": []} - SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response - - with pytest.raises(ModelPackageDoesntExistError): - SageMakerStub.get_model_package_arn(model_name="j2-mid", region="us-east-1") - - def test__list_model_package_versions__should_return_model_package_arn(self, mock_httpx_response: httpx.Response): - mock_httpx_response.json.return_value = { - "versions": _DUMMY_VERSIONS, - } - SageMakerStub.ai21_http_client.execute_http_request.return_value = mock_httpx_response - - actual_model_package_arn = SageMakerStub.list_model_package_versions(model_name="j2-mid", region="us-east-1") - - assert actual_model_package_arn == _DUMMY_VERSIONS - - def test__list_model_package_versions__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistError): - SageMakerStub.list_model_package_versions(model_name="openai", region="us-east-1") - - def test__get_model_package_arn__when_model_package_not_available__should_raise_an_error(self): - with pytest.raises(ModelPackageDoesntExistError): - SageMakerStub.get_model_package_arn(model_name="openai", region="us-east-1") diff --git a/tests/unittests/test_imports.py b/tests/unittests/test_imports.py index f4420575..7b6359f4 100644 --- a/tests/unittests/test_imports.py +++ b/tests/unittests/test_imports.py @@ -4,23 +4,20 @@ from ai21 import __all__ EXPECTED_ALL = [ - "AI21APIError", - "AI21AzureClient", - "AI21BedrockClient", - "AI21Client", "AI21EnvConfig", - "AI21Error", - "AI21SageMakerClient", - "APITimeoutError", - "AsyncAI21AzureClient", + "AI21Client", "AsyncAI21Client", - "BedrockModelID", + "AI21APIError", + "APITimeoutError", + "AI21Error", "MissingApiKeyError", "ModelPackageDoesntExistError", - "SageMaker", "TooManyRequestsError", + "AI21BedrockClient", + "BedrockModelID", + "AI21AzureClient", + "AsyncAI21AzureClient", "AsyncAI21BedrockClient", - "AsyncAI21SageMakerClient", "AI21VertexClient", "AsyncAI21VertexClient", ]