-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: Added code * feat: Added setup * feat: more sdk code * feat: poetry * feat: poetry setup * ci: actions * fix: removed unused example * feat: added boto3 * feat: added dependencies * test: added pytest dependency * fix: python version * fix: update python version in lock * fix: format * fix: examples * test: removed example script for studio and added integration test instead * test: bedrock integration test * test: moved examples * ci: fixed inv * fix: lint * feat: version in init * fix: long content * fix: poetry version * fix: added __all__ * fix: Added code to __all__ * fix: prompt * fix: test action * fix: Added shebang * fix: long line * fix: loaded env for tests * fix: Added env * test: only 3.10 * test: default region * test: Added 3.8 * fix: subscriptable type * test: sagemaker tests * fix: used _http methods * fix: default values * ci: removed -vv flag * fix: imports * test: Added conditional skip * fix: CR fixes * fix: boto3 to pyproject.toml * fix: all-extras arg * fix: lint in action * feat: via param * fix: added all extras * fix: Added static type checker * feat: Moved body creationto function * feat: switched most responses to use dataclasses_json * feat: Added base mixin * fix: CR * fix: test path * fix: CR * feat: Added bedrock session * feat: Added SageMakerSession * fix: init of bedrock client * feat: More robust imports * fix: error message * fix: removed kwargs from request body * fix: Removed log_level from env * fix: logger calls * fix: Removed logger from init * feat: Added setup logger * ci: Added integration tests only on push to main * fix: removed unused import
- Loading branch information
1 parent
a2841de
commit 0e9b36a
Showing
125 changed files
with
4,748 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# This workflow will upload a Python Package using Twine when a release is created | ||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries | ||
|
||
name: Publish to PYPI | ||
|
||
on: | ||
release: | ||
types: [published] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
deploy: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.10"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Install Poetry | ||
run: | | ||
pipx install poetry | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: poetry | ||
cache-dependency-path: poetry.lock | ||
- name: Set Poetry environment | ||
run: | | ||
poetry env use ${{ matrix.python-version }} | ||
- name: Build package | ||
run: poetry build | ||
- name: Publish package to PYPI | ||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 | ||
with: | ||
user: __token__ | ||
password: ${{ secrets.PYPI_API_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
name: Semantic Release | ||
|
||
on: | ||
workflow_dispatch: | ||
|
||
jobs: | ||
release: | ||
runs-on: ubuntu-latest | ||
concurrency: release | ||
permissions: | ||
id-token: write | ||
contents: write | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
fetch-depth: 0 | ||
persist-credentials: false | ||
|
||
- name: Python Semantic Release | ||
uses: python-semantic-release/[email protected] | ||
with: | ||
github_token: ${{ secrets.GH_PAT_SEM_REL_ASAFG }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
name: Test | ||
|
||
on: [push] | ||
|
||
env: | ||
POETRY_VERSION: "1.7.1" | ||
POETRY_URL: https://install.python-poetry.org | ||
|
||
jobs: | ||
lint: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.10"] | ||
|
||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
- name: Install Poetry | ||
run: | | ||
pipx install poetry | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: poetry | ||
cache-dependency-path: poetry.lock | ||
- name: Set Poetry environment | ||
run: | | ||
poetry env use ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
poetry install --no-root --only dev --all-extras | ||
- name: Lint Python (Black) | ||
run: | | ||
poetry run inv formatter | ||
- name: Lint Python (Ruff) | ||
run: | | ||
poetry run inv lint | ||
- name: Lint Python (isort) | ||
run: | | ||
poetry run inv isort | ||
unittests: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
- name: Install Poetry | ||
run: | | ||
pipx install poetry | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: poetry | ||
cache-dependency-path: poetry.lock | ||
- name: Set Poetry environment | ||
run: | | ||
poetry env use ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
poetry install --all-extras | ||
- name: Run Tests | ||
env: | ||
AI21_API_KEY: ${{ secrets.AI21_API_KEY }} | ||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} | ||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} | ||
run: | | ||
poetry run pytest | ||
- name: Upload pytest test results | ||
uses: actions/upload-artifact@v3 | ||
with: | ||
name: pytest-results-${{ matrix.python-version }} | ||
path: junit/test-results-${{ matrix.python-version }}.xml | ||
# Use always() to always run this step to publish test results when there are test failures | ||
if: ${{ always() }} | ||
|
||
integration-tests: | ||
runs-on: ubuntu-latest | ||
|
||
if: github.ref == 'refs/heads/main' | ||
|
||
strategy: | ||
matrix: | ||
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v3 | ||
- name: Install Poetry | ||
run: | | ||
pipx install poetry | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
cache: poetry | ||
cache-dependency-path: poetry.lock | ||
- name: Set Poetry environment | ||
run: | | ||
poetry env use ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
poetry install --all-extras | ||
- name: Run Integration Tests | ||
env: | ||
AI21_API_KEY: ${{ secrets.AI21_API_KEY }} | ||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} | ||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} | ||
run: | | ||
poetry run pytest tests/integration_tests/ | ||
- name: Upload pytest integration tests results | ||
uses: actions/upload-artifact@v3 | ||
with: | ||
name: pytest-results-${{ matrix.python-version }} | ||
path: junit/test-results-${{ matrix.python-version }}.xml | ||
# Use always() to always run this step to publish test results when there are test failures | ||
if: ${{ always() }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.10.6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Any | ||
|
||
from ai21.clients.studio.ai21_client import AI21Client | ||
from ai21.logger import setup_logger | ||
from ai21.resources.responses.answer_response import AnswerResponse | ||
from ai21.resources.responses.chat_response import ChatResponse | ||
from ai21.resources.responses.completion_response import CompletionsResponse | ||
from ai21.resources.responses.custom_model_response import CustomBaseModelResponse | ||
from ai21.resources.responses.dataset_response import DatasetResponse | ||
from ai21.resources.responses.embed_response import EmbedResponse | ||
from ai21.resources.responses.file_response import FileResponse | ||
from ai21.resources.responses.gec_response import GECResponse | ||
from ai21.resources.responses.improvement_response import ImprovementsResponse | ||
from ai21.resources.responses.library_answer_response import LibraryAnswerResponse | ||
from ai21.resources.responses.library_search_response import LibrarySearchResponse | ||
from ai21.resources.responses.paraphrase_response import ParaphraseResponse | ||
from ai21.resources.responses.segmentation_response import SegmentationResponse | ||
from ai21.resources.responses.summarize_by_segment_response import SummarizeBySegmentResponse | ||
from ai21.resources.responses.summarize_response import SummarizeResponse | ||
from ai21.services.sagemaker import SageMaker | ||
from ai21.version import VERSION | ||
|
||
__version__ = VERSION | ||
setup_logger() | ||
|
||
|
||
def _import_bedrock_client(): | ||
from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient | ||
|
||
return AI21BedrockClient | ||
|
||
|
||
def _import_sagemaker_client(): | ||
from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient | ||
|
||
return AI21SageMakerClient | ||
|
||
|
||
def _import_bedrock_model_id(): | ||
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID | ||
|
||
return BedrockModelID | ||
|
||
|
||
def __getattr__(name: str) -> Any: | ||
try: | ||
if name == "AI21BedrockClient": | ||
return _import_bedrock_client() | ||
|
||
if name == "AI21SageMakerClient": | ||
return _import_sagemaker_client() | ||
|
||
if name == "BedrockModelID": | ||
return _import_bedrock_model_id() | ||
except ImportError as e: | ||
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e | ||
|
||
|
||
__all__ = [ | ||
"AI21Client", | ||
"AI21BedrockClient", | ||
"AI21SageMakerClient", | ||
"BedrockModelID", | ||
"AnswerResponse", | ||
"ChatResponse", | ||
"CompletionsResponse", | ||
"CustomBaseModelResponse", | ||
"DatasetResponse", | ||
"EmbedResponse", | ||
"FileResponse", | ||
"GECResponse", | ||
"ImprovementsResponse", | ||
"LibraryAnswerResponse", | ||
"LibrarySearchResponse", | ||
"ParaphraseResponse", | ||
"SageMaker", | ||
"SegmentationResponse", | ||
"SummarizeBySegmentResponse", | ||
"SummarizeResponse", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from __future__ import annotations | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST | ||
|
||
|
||
@dataclass(frozen=True) | ||
class _AI21EnvConfig: | ||
api_key: Optional[str] = None | ||
api_url: Optional[str] = None | ||
api_version: str = DEFAULT_API_VERSION | ||
api_host: str = STUDIO_HOST | ||
organization: Optional[str] = None | ||
application: Optional[str] = None | ||
timeout_sec: Optional[int] = None | ||
num_retries: Optional[int] = None | ||
aws_region: Optional[str] = None | ||
log_level: Optional[str] = None | ||
|
||
@classmethod | ||
def from_env(cls) -> _AI21EnvConfig: | ||
return cls( | ||
api_key=os.getenv("AI21_API_KEY"), | ||
api_url=os.getenv("AI21_API_URL"), | ||
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION), | ||
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST), | ||
organization=os.getenv("AI21_ORGANIZATION"), | ||
application=os.getenv("AI21_APPLICATION"), | ||
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"), | ||
num_retries=os.getenv("AI21_NUM_RETRIES"), | ||
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"), | ||
log_level=os.getenv("AI21_LOG_LEVEL", "info"), | ||
) | ||
|
||
|
||
AI21EnvConfig = _AI21EnvConfig.from_env() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Optional, Dict, Any | ||
|
||
from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig | ||
from ai21.errors import MissingApiKeyException | ||
from ai21.http_client import HttpClient | ||
from ai21.version import VERSION | ||
|
||
|
||
class AI21StudioClient: | ||
def __init__( | ||
self, | ||
*, | ||
api_key: Optional[str] = None, | ||
api_host: Optional[str] = None, | ||
api_version: Optional[str] = None, | ||
headers: Optional[Dict[str, Any]] = None, | ||
timeout_sec: Optional[int] = None, | ||
num_retries: Optional[int] = None, | ||
organization: Optional[str] = None, | ||
via: Optional[str] = None, | ||
env_config: _AI21EnvConfig = AI21EnvConfig, | ||
): | ||
self._env_config = env_config | ||
self._api_key = api_key or self._env_config.api_key | ||
|
||
if self._api_key is None: | ||
raise MissingApiKeyException() | ||
|
||
self._api_host = api_host or self._env_config.api_host | ||
self._api_version = api_version or self._env_config.api_version | ||
self._headers = headers | ||
self._timeout_sec = timeout_sec or self._env_config.timeout_sec | ||
self._num_retries = num_retries or self._env_config.num_retries | ||
self._organization = organization or self._env_config.organization | ||
self._application = self._env_config.application | ||
self._via = via | ||
|
||
headers = self._build_headers(passed_headers=headers) | ||
|
||
self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers) | ||
|
||
def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]: | ||
headers = { | ||
"Content-Type": "application/json", | ||
"User-Agent": self._build_user_agent(), | ||
} | ||
|
||
if self._api_key: | ||
headers["Authorization"] = f"Bearer {self._api_key}" | ||
|
||
if passed_headers is not None: | ||
headers.update(passed_headers) | ||
|
||
return headers | ||
|
||
def _build_user_agent(self) -> str: | ||
user_agent = f"ai21 studio SDK {VERSION}" | ||
|
||
if self._organization is not None: | ||
user_agent = f"{user_agent} organization: {self._organization}" | ||
|
||
if self._application is not None: | ||
user_agent = f"{user_agent} application: {self._application}" | ||
|
||
if self._via is not None: | ||
user_agent = f"{user_agent} via: {self._via}" | ||
|
||
return user_agent | ||
|
||
def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None): | ||
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files) | ||
|
||
def get_base_url(self) -> str: | ||
return f"{self._api_host}/studio/{self._api_version}" |
Empty file.
Empty file.
Oops, something went wrong.