Commit

fix: Integration tests (#41)
* fix: types

* test: Added some integration tests

* test: improvements

* test: test_paraphrase.py

* fix: doc

* fix: removed unused comment

* test: test_summarize.py

* test: Added tests for test_summarize_by_segment.py

* test: test_segmentation.py

* fix: file id in library response

* fix: example for library

* ci: Add rc branch prefix trigger for integration tests (#43)

* ci: rc branch trigger for integration test

* fix: wrapped in quotes

* docs: docstrings

* fix: question

* fix: CR

* test: Added tests to segment type in embed
asafgardin authored Jan 30, 2024
1 parent 127cef4 commit 78709a7
Showing 28 changed files with 692 additions and 6 deletions.
11 changes: 10 additions & 1 deletion ai21/clients/common/answer_base.py
@@ -17,6 +17,15 @@ def create(
mode: Optional[Mode] = None,
**kwargs,
) -> AnswerResponse:
"""
:param context: A string containing the document context for which the question will be answered
:param question: A string containing the question to be answered based on the provided context.
:param answer_length: Approximate length of the answer in words.
:param mode:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> AnswerResponse:
@@ -26,7 +35,7 @@ def _create_body(
self,
context: str,
question: str,
answer_length: Optional[str],
answer_length: Optional[AnswerLength],
mode: Optional[str],
) -> Dict[str, Any]:
return {"context": context, "question": question, "answerLength": answer_length, "mode": mode}
19 changes: 19 additions & 0 deletions ai21/clients/common/chat_base.py
@@ -28,6 +28,25 @@ def create(
count_penalty: Optional[Penalty] = None,
**kwargs,
) -> ChatResponse:
"""
:param model: model type you wish to interact with
:param messages: A sequence of messages ingested by the model, which then returns the assistant's response
:param system: Offers the model overarching guidance on its response approach, encapsulating context, tone,
guardrails, and more
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ChatResponse:
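
A hedged sketch of a chat call with the parameters documented above. The ChatMessage/RoleType imports and the "j2-ultra" model name are assumptions about the SDK of this period, not something this diff shows.

from ai21 import AI21Client
from ai21.models import ChatMessage, RoleType

client = AI21Client()
response = client.chat.create(
    model="j2-ultra",                       # assumed model name
    system="You are a concise assistant.",  # overarching guidance for the model's responses
    messages=[ChatMessage(text="Name two provinces of the Netherlands.", role=RoleType.USER)],
    max_tokens=100,
    temperature=0.7,
)
print(response)
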
18 changes: 18 additions & 0 deletions ai21/clients/common/completion_base.py
@@ -28,6 +28,24 @@ def create(
epoch: Optional[int] = None,
**kwargs,
) -> CompletionsResponse:
"""
:param model: model type you wish to interact with
:param prompt: Text for model to complete
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param custom_model:
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param epoch:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> CompletionsResponse:
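
A similar sketch for the completion client; the model name and the completions[0].data.text response shape are assumptions based on the SDK's completion models, not confirmed by this diff.

from ai21 import AI21Client

client = AI21Client()
response = client.completion.create(
    model="j2-ultra",  # assumed model name
    prompt="Write a one-line tagline for a grammar-checking service:",
    max_tokens=50,
    temperature=0.7,
)
print(response.completions[0].data.text)  # assumed response shape
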
10 changes: 10 additions & 0 deletions ai21/clients/common/custom_model_base.py
@@ -18,6 +18,16 @@ def create(
num_epochs: Optional[int] = None,
**kwargs,
) -> None:
"""
:param dataset_id: The dataset you want to train your model on.
:param model_name: The name of your trained model
:param model_type: The type of model to train.
:param learning_rate: The learning rate used for training.
:param num_epochs: Number of epochs for training
:param kwargs:
:return:
"""
pass

@abstractmethod
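
A hedged sketch of starting a custom-model training job with the parameters documented above; the client.custom_model attribute, the placeholder dataset id, and the "j2-light" model type are all assumptions.

from ai21 import AI21Client

client = AI21Client()
client.custom_model.create(           # returns None, per the signature above
    dataset_id="your-dataset-id",     # placeholder: id of a dataset already uploaded to your account
    model_name="my-tuned-model",
    model_type="j2-light",            # assumed model type value
    num_epochs=1,
)
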
11 changes: 11 additions & 0 deletions ai21/clients/common/dataset_base.py
@@ -19,6 +19,17 @@ def create(
split_ratio: Optional[float] = None,
**kwargs,
):
"""
:param file_path: Local path to dataset
:param dataset_name: Dataset name. Must be unique
:param selected_columns: Mapping of the columns in the dataset file to prompt and completion columns.
:param approve_whitespace_correction: Automatically correct examples that violate best practices
:param delete_long_rows: Allow removal of examples where prompt + completion lengths exceeds 2047 tokens
:param split_ratio:
:param kwargs:
:return:
"""
pass

@abstractmethod
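
A hedged sketch of uploading a dataset with the parameters documented above; the client.dataset attribute name is an assumption and the file path is a placeholder.

from ai21 import AI21Client

client = AI21Client()
client.dataset.create(
    file_path="/path/to/train.csv",        # placeholder: local file with prompt/completion columns
    dataset_name="my-unique-dataset",
    approve_whitespace_correction=True,    # auto-correct examples that violate best practices
    delete_long_rows=True,                 # drop rows where prompt + completion exceeds 2047 tokens
)
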
8 changes: 8 additions & 0 deletions ai21/clients/common/embed_base.py
@@ -10,6 +10,14 @@ class Embed(ABC):

@abstractmethod
def create(self, texts: List[str], *, type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
"""
:param texts: A list of strings, each representing a document or segment of text to be embedded.
:param type: For retrieval/search use cases, indicates whether the texts that were
sent are segments or the query.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> EmbedResponse:
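
A short sketch of the embed client. The EmbedType import path is taken from the studio_embed.py change below; the SEGMENT member and the results[0].embedding response shape are assumptions.

from ai21 import AI21Client
from ai21.models.embed_type import EmbedType

client = AI21Client()
response = client.embed.create(
    texts=["Holland is a region on the western coast of the Netherlands."],
    type=EmbedType.SEGMENT,  # assumed member; use the query variant when embedding a search query
)
print(len(response.results[0].embedding))  # assumed response shape: one embedding vector per input text
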
6 changes: 6 additions & 0 deletions ai21/clients/common/gec_base.py
@@ -9,6 +9,12 @@ class GEC(ABC):

@abstractmethod
def create(self, text: str, **kwargs) -> GECResponse:
"""
:param text: The input text to be corrected.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> GECResponse:
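
A minimal sketch of the grammatical error correction client; only the response object's corrections field is assumed.

from ai21 import AI21Client

client = AI21Client()
response = client.gec.create(text="jazzz is a great stile of music")
print(response.corrections)  # assumed field: list of suggested corrections
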
7 changes: 7 additions & 0 deletions ai21/clients/common/improvements_base.py
@@ -10,6 +10,13 @@ class Improvements(ABC):

@abstractmethod
def create(self, text: str, types: List[ImprovementType], **kwargs) -> ImprovementsResponse:
"""
:param text: The input text to be improved.
:param types: Types of improvements to apply.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ImprovementsResponse:
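
A hedged sketch of the text improvements client; the ImprovementType.FLUENCY member and the response's improvements field are assumptions.

from ai21 import AI21Client
from ai21.models import ImprovementType

client = AI21Client()
response = client.improvements.create(
    text="Affiliated with the profession of project management, I have ample experience working cross-functionally.",
    types=[ImprovementType.FLUENCY],  # assumed enum member
)
print(response.improvements)  # assumed field: list of suggested rewrites
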
10 changes: 10 additions & 0 deletions ai21/clients/common/paraphrase_base.py
@@ -18,6 +18,16 @@ def create(
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
"""
:param text: The input text to be paraphrased.
:param style: Controls length and tone
:param start_index: Specifies the starting position of the paraphrasing process in the given text
:param end_index: Specifies the position of the last character to be paraphrased, including the character
following it. If the parameter is not provided, the default value is set to the length of the given text.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> ParaphraseResponse:
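
A hedged sketch of the paraphrase client; the ParaphraseStyleType.GENERAL member and the response's suggestions field are assumptions.

from ai21 import AI21Client
from ai21.models import ParaphraseStyleType

client = AI21Client()
response = client.paraphrase.create(
    text="The integration tests added in this commit cover most of the studio clients.",
    style=ParaphraseStyleType.GENERAL,  # assumed enum member controlling length and tone
    start_index=0,                      # paraphrase from the start of the text
)
print(response.suggestions)  # assumed field: list of paraphrase candidates
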
7 changes: 7 additions & 0 deletions ai21/clients/common/segmentation_base.py
@@ -10,6 +10,13 @@ class Segmentation(ABC):

@abstractmethod
def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse:
"""
:param source: Raw input text, or URL of a web page.
:param source_type: The type of the source - either TEXT or URL.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SegmentationResponse:
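
A short sketch of the segmentation client; the DocumentType import path and the response's segments field are assumptions (the diff only shows DocumentType used as a type annotation).

from ai21 import AI21Client
from ai21.models import DocumentType

client = AI21Client()
response = client.segmentation.create(
    source="https://en.wikipedia.org/wiki/Holland",
    source_type=DocumentType.URL,  # or DocumentType.TEXT for raw text input
)
print(response.segments)  # assumed field: list of detected segments
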
8 changes: 8 additions & 0 deletions ai21/clients/common/summarize_base.py
@@ -16,6 +16,14 @@ def create(
summary_method: Optional[SummaryMethod] = None,
**kwargs,
) -> SummarizeResponse:
"""
:param source: The input text, or URL of a web page to be summarized.
:param source_type: Either TEXT or URL
:param focus: Summaries focused on a topic of your choice.
:param summary_method:
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SummarizeResponse:
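
A hedged sketch of the summarize client; the plain-string source_type follows the docstring's "Either TEXT or URL", and the response's summary field is an assumption.

from ai21 import AI21Client

client = AI21Client()
response = client.summarize.create(
    source="Holland is a geographical region and former province on the western coast of the Netherlands.",
    source_type="TEXT",   # per the docstring above: either TEXT or URL
    focus="geography",    # optional: focus the summary on a topic of your choice
)
print(response.summary)   # assumed field on SummarizeResponse
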
8 changes: 8 additions & 0 deletions ai21/clients/common/summarize_by_segment_base.py
@@ -19,6 +19,14 @@ def create(
focus: Optional[str] = None,
**kwargs,
) -> SummarizeBySegmentResponse:
"""
:param source: The input text, or URL of a web page to be summarized.
:param source_type: Either TEXT or URL
:param focus: Summaries focused on a topic of your choice.
:param kwargs:
:return:
"""
pass

def _json_to_response(self, json: Dict[str, Any]) -> SummarizeBySegmentResponse:
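
A hedged sketch of the segment-wise summarization client; the client.summarize_by_segment attribute and the per-segment summary field are assumptions.

from ai21 import AI21Client

client = AI21Client()
response = client.summarize_by_segment.create(
    source="https://en.wikipedia.org/wiki/Holland",
    source_type="URL",    # per the docstring above: either TEXT or URL
)
for segment in response.segments:   # assumed field: one entry per detected segment
    print(segment.summary)
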
3 changes: 2 additions & 1 deletion ai21/clients/studio/resources/studio_embed.py
@@ -2,11 +2,12 @@

from ai21.clients.common.embed_base import Embed
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.models.embed_type import EmbedType
from ai21.models.responses.embed_response import EmbedResponse


class StudioEmbed(StudioResource, Embed):
def create(self, texts: List[str], type: Optional[str] = None, **kwargs) -> EmbedResponse:
def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse:
url = f"{self._client.get_base_url()}/{self._module_name}"
body = self._create_body(texts=texts, type=type)
response = self._post(url=url, body=body)
1 change: 0 additions & 1 deletion ai21/clients/studio/resources/studio_summarize.py
@@ -18,7 +18,6 @@ def create(
summary_method: Optional[SummaryMethod] = None,
**kwargs,
) -> SummarizeResponse:
# Make a summarize request to the AI21 API. Returns the response either as a string or a AI21Summarize object.
body = self._create_body(
source=source,
source_type=source_type,
2 changes: 1 addition & 1 deletion ai21/models/responses/library_answer_response.py
@@ -6,7 +6,7 @@

@dataclass
class SourceDocument(AI21BaseModelMixin):
field_id: str
file_id: str
name: str
highlights: List[str]
public_url: Optional[str] = None
8 changes: 7 additions & 1 deletion examples/studio/library.py
@@ -22,7 +22,12 @@ def validate_file_deleted():
file_path = os.getcwd()

path = os.path.join(file_path, file_name)
file_utils.create_file(file_path, file_name, content="test content" * 100)
_SOURCE_TEXT = """Holland is a geographical region and former province on the western coast of the
Netherlands. From the 10th to the 16th century, Holland proper was a unified political
region within the Holy Roman Empire as a county ruled by the counts of Holland.
By the 17th century, the province of Holland had risen to become a maritime and economic power,
dominating the other provinces of the newly independent Dutch Republic."""
file_utils.create_file(file_path, file_name, content=_SOURCE_TEXT)

file_id = client.library.files.create(
file_path=path,
@@ -31,6 +36,7 @@
public_url="www.example.com",
)
print(file_id)

files = client.library.files.list()
print(files)
uploaded_file = client.library.files.get(file_id)
2 changes: 1 addition & 1 deletion examples/studio/library_answer.py
@@ -2,5 +2,5 @@


client = AI21Client()
response = client.library.answer.create(question="Where is Thailand?")
response = client.library.answer.create(question="Can you tell me something about Holland?")
print(response)
Empty file.
38 changes: 38 additions & 0 deletions tests/integration_tests/clients/studio/test_answer.py
@@ -0,0 +1,38 @@
import pytest
from ai21 import AI21Client
from ai21.models import AnswerLength, Mode

_CONTEXT = (
"Holland is a geographical region[2] and former province on the western coast of"
" the Netherlands. From the "
"10th to the 16th century, Holland proper was a unified political region within the Holy Roman Empire as a county "
"ruled by the counts of Holland. By the 17th century, the province of Holland had risen to become a maritime and "
"economic power, dominating the other provinces of the newly independent Dutch Republic."
)


@pytest.mark.parametrize(
ids=[
"when_answer_is_in_context",
"when_answer_not_in_context",
],
argnames=["question", "is_answer_in_context", "expected_answer_type"],
argvalues=[
("When did Holland become an economic power?", True, str),
("Is the ocean blue?", False, None),
],
)
def test_answer(question: str, is_answer_in_context: bool, expected_answer_type: type):
client = AI21Client()
response = client.answer.create(
context=_CONTEXT,
question=question,
answer_length=AnswerLength.LONG,
mode=Mode.FLEXIBLE,
)

assert response.answer_in_context == is_answer_in_context
if is_answer_in_context:
assert isinstance(response.answer, str)
else:
assert response.answer is None