Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SDK code #3

Merged
merged 65 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
65 commits
Select commit Hold shift + click to select a range
65420f6
feat: Added code
asafgardin Dec 11, 2023
028b368
feat: Added setup
asafgardin Dec 11, 2023
bafe1f1
feat: more sdk code
asafgardin Dec 11, 2023
db5515f
feat: poetry
asafgardin Dec 11, 2023
9115a72
feat: poetry setup
asafgardin Dec 11, 2023
cf0a521
ci: actions
asafgardin Dec 11, 2023
f0e54b1
fix: removed unused example
asafgardin Dec 12, 2023
d4215d5
feat: added boto3
asafgardin Dec 12, 2023
cfd531b
feat: added dependencies
asafgardin Dec 12, 2023
29b974c
test: added pytest dependency
asafgardin Dec 12, 2023
6a6e455
fix: python version
asafgardin Dec 12, 2023
908baf4
fix: update python version in lock
asafgardin Dec 12, 2023
aef1314
fix: format
asafgardin Dec 12, 2023
0d7c7ad
fix: examples
asafgardin Dec 12, 2023
af282ba
test: removed example script for studio and added integration test in…
asafgardin Dec 12, 2023
4909f40
test: bedrock integration test
asafgardin Dec 12, 2023
42bd27c
test: moved examples
asafgardin Dec 12, 2023
a35357c
ci: fixed inv
asafgardin Dec 12, 2023
52dd875
fix: lint
asafgardin Dec 12, 2023
a089b2e
feat: version in init
asafgardin Dec 12, 2023
1cbb45c
fix: long content
asafgardin Dec 12, 2023
9fa54b2
fix: poetry version
asafgardin Dec 12, 2023
69c9946
fix: added __all__
asafgardin Dec 12, 2023
61e465c
fix: Added code to __all__
asafgardin Dec 12, 2023
824c492
fix: prompt
asafgardin Dec 12, 2023
496542a
fix: test action
asafgardin Dec 12, 2023
26ec177
fix: Added shebang
asafgardin Dec 12, 2023
80591e3
fix: long line
asafgardin Dec 12, 2023
ec0dc35
fix: loaded env for tests
asafgardin Dec 12, 2023
9102073
fix: Added env
asafgardin Dec 12, 2023
1093e5c
test: only 3.10
asafgardin Dec 12, 2023
98f12b0
test: default region
asafgardin Dec 12, 2023
cdaa0fd
test: Added 3.8
asafgardin Dec 12, 2023
71d1cc9
fix: subscriptable type
asafgardin Dec 12, 2023
5b753ea
test: sagemaker tests
asafgardin Dec 12, 2023
e7babfc
fix: used _http methods
asafgardin Dec 12, 2023
c039dce
fix: default values
asafgardin Dec 12, 2023
ac14452
ci: removed -vv flag
asafgardin Dec 12, 2023
11f45f7
fix: imports
asafgardin Dec 12, 2023
67039a0
test: Added conditional skip
asafgardin Dec 12, 2023
4f40004
fix: CR fixes
asafgardin Dec 13, 2023
0cb4eaa
fix: boto3 to pyproject.toml
asafgardin Dec 13, 2023
39d69a9
fix: all-extras arg
asafgardin Dec 14, 2023
1d820a8
fix: lint in action
asafgardin Dec 14, 2023
d51f101
feat: via param
asafgardin Dec 14, 2023
f00fd8c
fix: added all extras
asafgardin Dec 14, 2023
3a48e8c
fix: Added static type checker
asafgardin Dec 14, 2023
ca9db03
feat: Moved body creationto function
asafgardin Dec 14, 2023
3f8a154
feat: switched most responses to use dataclasses_json
asafgardin Dec 17, 2023
fb695c3
feat: Added base mixin
asafgardin Dec 17, 2023
63a22bd
fix: CR
asafgardin Dec 17, 2023
dc3523b
fix: test path
asafgardin Dec 17, 2023
29c1bc4
fix: CR
asafgardin Dec 17, 2023
62dac90
feat: Added bedrock session
asafgardin Dec 17, 2023
5446c69
feat: Added SageMakerSession
asafgardin Dec 17, 2023
c8057d9
fix: init of bedrock client
asafgardin Dec 17, 2023
29c24e0
feat: More robust imports
asafgardin Dec 17, 2023
6f8a634
fix: error message
asafgardin Dec 17, 2023
be46ab1
fix: removed kwargs from request body
asafgardin Dec 17, 2023
8a2966e
fix: Removed log_level from env
asafgardin Dec 17, 2023
0e8c60e
fix: logger calls
asafgardin Dec 17, 2023
142d7e1
fix: Removed logger from init
asafgardin Dec 17, 2023
4ad7a76
feat: Added setup logger
asafgardin Dec 18, 2023
759d95b
ci: Added integration tests only on push to main
asafgardin Dec 18, 2023
6141960
fix: removed unused import
asafgardin Dec 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

name: Publish to PYPI

on:
release:
types: [published]

permissions:
contents: read

jobs:
deploy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Build package
run: poetry build
- name: Publish package to PYPI
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
23 changes: 23 additions & 0 deletions .github/workflows/release_version.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Semantic Release

on:
workflow_dispatch:

jobs:
release:
runs-on: ubuntu-latest
concurrency: release
permissions:
id-token: write
contents: write

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false

- name: Python Semantic Release
uses: python-semantic-release/[email protected]
with:
github_token: ${{ secrets.GH_PAT_SEM_REL_ASAFG }}
120 changes: 120 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
name: Test

on: [push]

env:
POETRY_VERSION: "1.7.1"
POETRY_URL: https://install.python-poetry.org

jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --no-root --only dev --all-extras
- name: Lint Python (Black)
run: |
poetry run inv formatter
- name: Lint Python (Ruff)
run: |
poetry run inv lint
- name: Lint Python (isort)
run: |
poetry run inv isort
unittests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --all-extras
- name: Run Tests
env:
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
poetry run pytest
- name: Upload pytest test results
uses: actions/upload-artifact@v3
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: ${{ always() }}

integration-tests:
runs-on: ubuntu-latest

if: github.ref == 'refs/heads/main'

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --all-extras
- name: Run Integration Tests
env:
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
poetry run pytest tests/integration_tests/
- name: Upload pytest integration tests results
uses: actions/upload-artifact@v3
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: ${{ always() }}
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.6
80 changes: 80 additions & 0 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any

from ai21.clients.studio.ai21_client import AI21Client
from ai21.logger import setup_logger
from ai21.resources.responses.answer_response import AnswerResponse
from ai21.resources.responses.chat_response import ChatResponse
from ai21.resources.responses.completion_response import CompletionsResponse
from ai21.resources.responses.custom_model_response import CustomBaseModelResponse
from ai21.resources.responses.dataset_response import DatasetResponse
from ai21.resources.responses.embed_response import EmbedResponse
from ai21.resources.responses.file_response import FileResponse
from ai21.resources.responses.gec_response import GECResponse
from ai21.resources.responses.improvement_response import ImprovementsResponse
from ai21.resources.responses.library_answer_response import LibraryAnswerResponse
from ai21.resources.responses.library_search_response import LibrarySearchResponse
from ai21.resources.responses.paraphrase_response import ParaphraseResponse
from ai21.resources.responses.segmentation_response import SegmentationResponse
from ai21.resources.responses.summarize_by_segment_response import SummarizeBySegmentResponse
from ai21.resources.responses.summarize_response import SummarizeResponse
from ai21.services.sagemaker import SageMaker
from ai21.version import VERSION

__version__ = VERSION
setup_logger()


def _import_bedrock_client():
from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient

return AI21BedrockClient


def _import_sagemaker_client():
from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient

return AI21SageMakerClient


def _import_bedrock_model_id():
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID

return BedrockModelID


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
return _import_bedrock_client()

if name == "AI21SageMakerClient":
return _import_sagemaker_client()

if name == "BedrockModelID":
return _import_bedrock_model_id()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e


__all__ = [
"AI21Client",
"AI21BedrockClient",
"AI21SageMakerClient",
"BedrockModelID",
"AnswerResponse",
"ChatResponse",
"CompletionsResponse",
"CustomBaseModelResponse",
"DatasetResponse",
"EmbedResponse",
"FileResponse",
"GECResponse",
"ImprovementsResponse",
"LibraryAnswerResponse",
"LibrarySearchResponse",
"ParaphraseResponse",
"SageMaker",
"SegmentationResponse",
"SummarizeBySegmentResponse",
"SummarizeResponse",
]
38 changes: 38 additions & 0 deletions ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Optional

from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST


@dataclass(frozen=True)
class _AI21EnvConfig:
api_key: Optional[str] = None
api_url: Optional[str] = None
api_version: str = DEFAULT_API_VERSION
api_host: str = STUDIO_HOST
organization: Optional[str] = None
application: Optional[str] = None
timeout_sec: Optional[int] = None
num_retries: Optional[int] = None
aws_region: Optional[str] = None
log_level: Optional[str] = None

@classmethod
def from_env(cls) -> _AI21EnvConfig:
return cls(
api_key=os.getenv("AI21_API_KEY"),
api_url=os.getenv("AI21_API_URL"),
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION),
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST),
organization=os.getenv("AI21_ORGANIZATION"),
application=os.getenv("AI21_APPLICATION"),
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"),
num_retries=os.getenv("AI21_NUM_RETRIES"),
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"),
log_level=os.getenv("AI21_LOG_LEVEL", "info"),
)


AI21EnvConfig = _AI21EnvConfig.from_env()
74 changes: 74 additions & 0 deletions ai21/ai21_studio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Optional, Dict, Any

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyException
from ai21.http_client import HttpClient
from ai21.version import VERSION


class AI21StudioClient:
def __init__(
self,
*,
api_key: Optional[str] = None,
api_host: Optional[str] = None,
api_version: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[int] = None,
num_retries: Optional[int] = None,
organization: Optional[str] = None,
via: Optional[str] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
self._env_config = env_config
self._api_key = api_key or self._env_config.api_key

if self._api_key is None:
raise MissingApiKeyException()

self._api_host = api_host or self._env_config.api_host
self._api_version = api_version or self._env_config.api_version
self._headers = headers
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
self._num_retries = num_retries or self._env_config.num_retries
self._organization = organization or self._env_config.organization
self._application = self._env_config.application
self._via = via

headers = self._build_headers(passed_headers=headers)

self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers)

def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"User-Agent": self._build_user_agent(),
}

if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"

if passed_headers is not None:
headers.update(passed_headers)

return headers

def _build_user_agent(self) -> str:
user_agent = f"ai21 studio SDK {VERSION}"

if self._organization is not None:
user_agent = f"{user_agent} organization: {self._organization}"

if self._application is not None:
user_agent = f"{user_agent} application: {self._application}"

if self._via is not None:
user_agent = f"{user_agent} via: {self._via}"

return user_agent

def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None):
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
Empty file added ai21/clients/__init__.py
Empty file.
Empty file.
Loading
Loading