Skip to content

Commit

Permalink
fix: imports
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Dec 12, 2023
1 parent ac14452 commit 11f45f7
Show file tree
Hide file tree
Showing 13 changed files with 45 additions and 93 deletions.
10 changes: 4 additions & 6 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient
from .clients.bedrock.bedrock_model_id import BedrockModelID
from .clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient
from .clients.studio.ai21_client import AI21Client
from .version import VERSION
from .clients import AI21Client, AI21BedrockClient, AI21SageMakerClient

__version__ = VERSION

__all__ = ["AI21Client", "AI21BedrockClient", "AI21SageMakerClient", "BedrockModelID"]


class BedrockModelID:
J2_MID_V1 = "ai21.j2-mid-v1"
J2_ULTRA_V1 = "ai21.j2-ultra-v1"
5 changes: 0 additions & 5 deletions ai21/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from .sagemaker import AI21SageMakerClient
from .bedrock import AI21BedrockClient
from .studio import AI21Client

__all__ = ["AI21Client", "AI21BedrockClient", "AI21SageMakerClient"]
3 changes: 0 additions & 3 deletions ai21/clients/bedrock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .ai21_bedrock_client import AI21BedrockClient

__all__ = ["AI21BedrockClient"]
10 changes: 2 additions & 8 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
from botocore.exceptions import ClientError

from ai21.ai21_env_config import AI21EnvConfig, _AI21EnvConfig
from ai21.clients.bedrock import resources
from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion
from ai21.errors import AccessDenied, NotFound, APITimeoutError
from ai21.http_client import handle_non_success_response
from ai21.utils import log_error

__all__ = [
"resources",
"AI21BedrockClient",
]


RUNTIME_NAME = "bedrock-runtime"
_ERROR_MSG_TEMPLATE = (
r"Received client error \((.*?)\) from primary with message \"(.*?)\". "
Expand All @@ -37,7 +31,7 @@ def __init__(
self._session = (
session.client(RUNTIME_NAME) if session else boto3.client(RUNTIME_NAME, region_name=env_config.aws_region)
)
self.completion = resources.BedrockCompletion(self)
self.completion = BedrockCompletion(self)

def invoke_model(self, model_id: str, input_json: str) -> Dict[str, Any]:
try:
Expand Down
3 changes: 3 additions & 0 deletions ai21/clients/bedrock/bedrock_model_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class BedrockModelID:
J2_MID_V1 = "ai21.j2-mid-v1"
J2_ULTRA_V1 = "ai21.j2-ultra-v1"
3 changes: 0 additions & 3 deletions ai21/clients/bedrock/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .bedrock_completion import BedrockCompletion

__all__ = ["BedrockCompletion"]
3 changes: 0 additions & 3 deletions ai21/clients/sagemaker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .ai21_sagemaker_client import AI21SageMakerClient

__all__ = ["AI21SageMakerClient"]
18 changes: 10 additions & 8 deletions ai21/clients/sagemaker/ai21_sagemaker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_studio_client import AI21StudioClient
from ai21.clients.sagemaker import resources
from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer
from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion
from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC
from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase
from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize
from ai21.errors import BadRequest, ServiceUnavailable, ServerError, APIError
from ai21.http_client import handle_non_success_response
from ai21.utils import log_error

__all__ = ["resources", "AI21SageMakerClient"]

# Each one of the clients should be able to implement async/sync interface
_ERROR_MSG_TEMPLATE = (
r"Received client error \((.*?)\) from primary with message \"(.*?)\". "
Expand Down Expand Up @@ -56,11 +58,11 @@ def __init__(
)
self._region = region or self._env_config.aws_region
self._endpoint_name = endpoint_name
self.completion = resources.SageMakerCompletion(self)
self.paraphrase = resources.SageMakerParaphrase(self)
self.answer = resources.SageMakerAnswer(self)
self.gec = resources.SageMakerGEC(self)
self.summarize = resources.SageMakerSummarize(self)
self.completion = SageMakerCompletion(self)
self.paraphrase = SageMakerParaphrase(self)
self.answer = SageMakerAnswer(self)
self.gec = SageMakerGEC(self)
self.summarize = SageMakerSummarize(self)

def invoke_endpoint(
self,
Expand Down
2 changes: 0 additions & 2 deletions ai21/clients/sagemaker/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
SAGEMAKER_ENDPOINT_KEY = "sm_endpoint"

SAGEMAKER_MODEL_PACKAGE_NAMES = [
"j2-light",
"j2-mid",
Expand Down
7 changes: 0 additions & 7 deletions ai21/clients/sagemaker/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
from .sagemaker_completion import SageMakerCompletion
from .sagemaker_gec import SageMakerGEC
from .sagemaker_paraphrase import SageMakerParaphrase
from .sagemaker_summarize import SageMakerSummarize
from .sagemaker_answer import SageMakerAnswer

__all__ = ["SageMakerSummarize", "SageMakerParaphrase", "SageMakerGEC", "SageMakerCompletion", "SageMakerAnswer"]
3 changes: 0 additions & 3 deletions ai21/clients/studio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .ai21_client import AI21Client

__all__ = ["AI21Client"]
42 changes: 26 additions & 16 deletions ai21/clients/studio/ai21_client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from typing import Optional, Any, Dict

from ai21.ai21_studio_client import AI21StudioClient
from ai21.clients.studio import resources
from ai21.clients.studio.resources.studio_answer import StudioAnswer
from ai21.clients.studio.resources.studio_chat import StudioChat
from ai21.clients.studio.resources.studio_completion import StudioCompletion
from ai21.clients.studio.resources.studio_custom_model import StudioCustomModel
from ai21.clients.studio.resources.studio_dataset import StudioDataset
from ai21.clients.studio.resources.studio_embed import StudioEmbed
from ai21.clients.studio.resources.studio_gec import StudioGEC
from ai21.clients.studio.resources.studio_improvements import StudioImprovements
from ai21.clients.studio.resources.studio_library import StudioLibrary
from ai21.clients.studio.resources.studio_paraphrase import StudioParaphrase
from ai21.clients.studio.resources.studio_segmentation import StudioSegmentation
from ai21.clients.studio.resources.studio_summarize import StudioSummarize
from ai21.clients.studio.resources.studio_summarize_by_segment import StudioSummarizeBySegment
from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer
from ai21.tokenizers.factory import get_tokenizer

__all__ = ["AI21Client", "resources"]


class AI21Client(AI21StudioClient):
"""
Expand All @@ -33,19 +43,19 @@ def __init__(
timeout_sec=timeout_sec,
num_retries=num_retries,
)
self.completion = resources.StudioCompletion(self)
self.chat = resources.StudioChat(self)
self.summarize = resources.StudioSummarize(self)
self.embed = resources.StudioEmbed(self)
self.gec = resources.StudioGEC(self)
self.improvements = resources.StudioImprovements(self)
self.paraphrase = resources.StudioParaphrase(self)
self.summarize_by_segment = resources.StudioSummarizeBySegment(self)
self.custom_model = resources.StudioCustomModel(self)
self.dataset = resources.StudioDataset(self)
self.answer = resources.StudioAnswer(self)
self.library = resources.StudioLibrary(self)
self.segmentation = resources.StudioSegmentation(self)
self.completion = StudioCompletion(self)
self.chat = StudioChat(self)
self.summarize = StudioSummarize(self)
self.embed = StudioEmbed(self)
self.gec = StudioGEC(self)
self.improvements = StudioImprovements(self)
self.paraphrase = StudioParaphrase(self)
self.summarize_by_segment = StudioSummarizeBySegment(self)
self.custom_model = StudioCustomModel(self)
self.dataset = StudioDataset(self)
self.answer = StudioAnswer(self)
self.library = StudioLibrary(self)
self.segmentation = StudioSegmentation(self)

def count_token(self, text: str, model_id: str = "j2-instruct") -> int:
# We might want to cache the tokenizer instance within the class
Expand Down
29 changes: 0 additions & 29 deletions ai21/clients/studio/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,29 +0,0 @@
from .studio_answer import StudioAnswer
from .studio_chat import StudioChat
from .studio_completion import StudioCompletion
from .studio_custom_model import StudioCustomModel
from .studio_dataset import StudioDataset
from .studio_embed import StudioEmbed
from .studio_gec import StudioGEC
from .studio_improvements import StudioImprovements
from .studio_library import StudioLibrary
from .studio_paraphrase import StudioParaphrase
from .studio_segmentation import StudioSegmentation
from .studio_summarize import StudioSummarize
from .studio_summarize_by_segment import StudioSummarizeBySegment

__all__ = [
"StudioSummarizeBySegment",
"StudioSummarize",
"StudioSegmentation",
"StudioParaphrase",
"StudioLibrary",
"StudioImprovements",
"StudioGEC",
"StudioEmbed",
"StudioDataset",
"StudioCustomModel",
"StudioCompletion",
"StudioChat",
"StudioAnswer",
]

0 comments on commit 11f45f7

Please sign in to comment.