Skip to content

Commit

Permalink
feat: SDK code (#3)
Browse files Browse the repository at this point in the history
* feat: Added code

* feat: Added setup

* feat: more sdk code

* feat: poetry

* feat: poetry setup

* ci: actions

* fix: removed unused example

* feat: added boto3

* feat: added dependencies

* test: added pytest dependency

* fix: python version

* fix: update python version in lock

* fix: format

* fix: examples

* test: removed example script for studio and added integration test instead

* test: bedrock integration test

* test: moved examples

* ci: fixed inv

* fix: lint

* feat: version in init

* fix: long content

* fix: poetry version

* fix: added __all__

* fix: Added code to __all__

* fix: prompt

* fix: test action

* fix: Added shebang

* fix: long line

* fix: loaded env for tests

* fix: Added env

* test: only 3.10

* test: default region

* test: Added 3.8

* fix: subscriptable type

* test: sagemaker tests

* fix: used _http methods

* fix: default values

* ci: removed -vv flag

* fix: imports

* test: Added conditional skip

* fix: CR fixes

* fix: boto3 to pyproject.toml

* fix: all-extras arg

* fix: lint in action

* feat: via param

* fix: added all extras

* fix: Added static type checker

* feat: Moved body creationto function

* feat: switched most responses to use dataclasses_json

* feat: Added base mixin

* fix: CR

* fix: test path

* fix: CR

* feat: Added bedrock session

* feat: Added SageMakerSession

* fix: init of bedrock client

* feat: More robust imports

* fix: error message

* fix: removed kwargs from request body

* fix: Removed log_level from env

* fix: logger calls

* fix: Removed logger from init

* feat: Added setup logger

* ci: Added integration tests only on push to main

* fix: removed unused import
  • Loading branch information
asafgardin authored Dec 18, 2023
1 parent a2841de commit 0e9b36a
Show file tree
Hide file tree
Showing 125 changed files with 4,748 additions and 0 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

name: Publish to PYPI

on:
release:
types: [published]

permissions:
contents: read

jobs:
deploy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Build package
run: poetry build
- name: Publish package to PYPI
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
23 changes: 23 additions & 0 deletions .github/workflows/release_version.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Semantic Release

on:
workflow_dispatch:

jobs:
release:
runs-on: ubuntu-latest
concurrency: release
permissions:
id-token: write
contents: write

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
persist-credentials: false

- name: Python Semantic Release
uses: python-semantic-release/[email protected]
with:
github_token: ${{ secrets.GH_PAT_SEM_REL_ASAFG }}
120 changes: 120 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
name: Test

on: [push]

env:
POETRY_VERSION: "1.7.1"
POETRY_URL: https://install.python-poetry.org

jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]

steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --no-root --only dev --all-extras
- name: Lint Python (Black)
run: |
poetry run inv formatter
- name: Lint Python (Ruff)
run: |
poetry run inv lint
- name: Lint Python (isort)
run: |
poetry run inv isort
unittests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --all-extras
- name: Run Tests
env:
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
poetry run pytest
- name: Upload pytest test results
uses: actions/upload-artifact@v3
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: ${{ always() }}

integration-tests:
runs-on: ubuntu-latest

if: github.ref == 'refs/heads/main'

strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Install Poetry
run: |
pipx install poetry
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: poetry
cache-dependency-path: poetry.lock
- name: Set Poetry environment
run: |
poetry env use ${{ matrix.python-version }}
- name: Install dependencies
run: |
poetry install --all-extras
- name: Run Integration Tests
env:
AI21_API_KEY: ${{ secrets.AI21_API_KEY }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
poetry run pytest tests/integration_tests/
- name: Upload pytest integration tests results
uses: actions/upload-artifact@v3
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures
if: ${{ always() }}
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.10.6
80 changes: 80 additions & 0 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any

from ai21.clients.studio.ai21_client import AI21Client
from ai21.logger import setup_logger
from ai21.resources.responses.answer_response import AnswerResponse
from ai21.resources.responses.chat_response import ChatResponse
from ai21.resources.responses.completion_response import CompletionsResponse
from ai21.resources.responses.custom_model_response import CustomBaseModelResponse
from ai21.resources.responses.dataset_response import DatasetResponse
from ai21.resources.responses.embed_response import EmbedResponse
from ai21.resources.responses.file_response import FileResponse
from ai21.resources.responses.gec_response import GECResponse
from ai21.resources.responses.improvement_response import ImprovementsResponse
from ai21.resources.responses.library_answer_response import LibraryAnswerResponse
from ai21.resources.responses.library_search_response import LibrarySearchResponse
from ai21.resources.responses.paraphrase_response import ParaphraseResponse
from ai21.resources.responses.segmentation_response import SegmentationResponse
from ai21.resources.responses.summarize_by_segment_response import SummarizeBySegmentResponse
from ai21.resources.responses.summarize_response import SummarizeResponse
from ai21.services.sagemaker import SageMaker
from ai21.version import VERSION

__version__ = VERSION
setup_logger()


def _import_bedrock_client():
from ai21.clients.bedrock.ai21_bedrock_client import AI21BedrockClient

return AI21BedrockClient


def _import_sagemaker_client():
from ai21.clients.sagemaker.ai21_sagemaker_client import AI21SageMakerClient

return AI21SageMakerClient


def _import_bedrock_model_id():
from ai21.clients.bedrock.bedrock_model_id import BedrockModelID

return BedrockModelID


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
return _import_bedrock_client()

if name == "AI21SageMakerClient":
return _import_sagemaker_client()

if name == "BedrockModelID":
return _import_bedrock_model_id()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e


__all__ = [
"AI21Client",
"AI21BedrockClient",
"AI21SageMakerClient",
"BedrockModelID",
"AnswerResponse",
"ChatResponse",
"CompletionsResponse",
"CustomBaseModelResponse",
"DatasetResponse",
"EmbedResponse",
"FileResponse",
"GECResponse",
"ImprovementsResponse",
"LibraryAnswerResponse",
"LibrarySearchResponse",
"ParaphraseResponse",
"SageMaker",
"SegmentationResponse",
"SummarizeBySegmentResponse",
"SummarizeResponse",
]
38 changes: 38 additions & 0 deletions ai21/ai21_env_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Optional

from ai21.constants import DEFAULT_API_VERSION, STUDIO_HOST


@dataclass(frozen=True)
class _AI21EnvConfig:
api_key: Optional[str] = None
api_url: Optional[str] = None
api_version: str = DEFAULT_API_VERSION
api_host: str = STUDIO_HOST
organization: Optional[str] = None
application: Optional[str] = None
timeout_sec: Optional[int] = None
num_retries: Optional[int] = None
aws_region: Optional[str] = None
log_level: Optional[str] = None

@classmethod
def from_env(cls) -> _AI21EnvConfig:
return cls(
api_key=os.getenv("AI21_API_KEY"),
api_url=os.getenv("AI21_API_URL"),
api_version=os.getenv("AI21_API_VERSION", DEFAULT_API_VERSION),
api_host=os.getenv("AI21_API_HOST", STUDIO_HOST),
organization=os.getenv("AI21_ORGANIZATION"),
application=os.getenv("AI21_APPLICATION"),
timeout_sec=os.getenv("AI21_TIMEOUT_SEC"),
num_retries=os.getenv("AI21_NUM_RETRIES"),
aws_region=os.getenv("AI21_AWS_REGION", "us-east-1"),
log_level=os.getenv("AI21_LOG_LEVEL", "info"),
)


AI21EnvConfig = _AI21EnvConfig.from_env()
74 changes: 74 additions & 0 deletions ai21/ai21_studio_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Optional, Dict, Any

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.errors import MissingApiKeyException
from ai21.http_client import HttpClient
from ai21.version import VERSION


class AI21StudioClient:
def __init__(
self,
*,
api_key: Optional[str] = None,
api_host: Optional[str] = None,
api_version: Optional[str] = None,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[int] = None,
num_retries: Optional[int] = None,
organization: Optional[str] = None,
via: Optional[str] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
self._env_config = env_config
self._api_key = api_key or self._env_config.api_key

if self._api_key is None:
raise MissingApiKeyException()

self._api_host = api_host or self._env_config.api_host
self._api_version = api_version or self._env_config.api_version
self._headers = headers
self._timeout_sec = timeout_sec or self._env_config.timeout_sec
self._num_retries = num_retries or self._env_config.num_retries
self._organization = organization or self._env_config.organization
self._application = self._env_config.application
self._via = via

headers = self._build_headers(passed_headers=headers)

self.http_client = HttpClient(timeout_sec=timeout_sec, num_retries=num_retries, headers=headers)

def _build_headers(self, passed_headers: Optional[Dict[str, Any]]) -> Dict[str, Any]:
headers = {
"Content-Type": "application/json",
"User-Agent": self._build_user_agent(),
}

if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"

if passed_headers is not None:
headers.update(passed_headers)

return headers

def _build_user_agent(self) -> str:
user_agent = f"ai21 studio SDK {VERSION}"

if self._organization is not None:
user_agent = f"{user_agent} organization: {self._organization}"

if self._application is not None:
user_agent = f"{user_agent} application: {self._application}"

if self._via is not None:
user_agent = f"{user_agent} via: {self._via}"

return user_agent

def execute_http_request(self, method: str, url: str, params: Optional[Dict] = None, files=None):
return self.http_client.execute_http_request(method=method, url=url, params=params, files=files)

def get_base_url(self) -> str:
return f"{self._api_host}/studio/{self._api_version}"
Empty file added ai21/clients/__init__.py
Empty file.
Empty file.
Loading

0 comments on commit 0e9b36a

Please sign in to comment.