Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ add support for Thread resource #232

Merged
merged 3 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from ai21.models.responses.assistant_response import (
AssistantResponse,
Optimization,
ToolResources,
Tool,
ListAssistantResponse,
)
from ai21.models.assistant.assistant import Optimization, Tool, ToolResources
from ai21.models.responses.assistant_response import Assistant, ListAssistant
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given


class Assistant(ABC):
class Assistants(ABC):
_module_name = "assistants"

@abstractmethod
Expand All @@ -29,7 +24,7 @@ def create(
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> AssistantResponse:
) -> Assistant:
pass

def _create_body(
Expand Down Expand Up @@ -58,11 +53,11 @@ def _create_body(
)

@abstractmethod
def list(self) -> ListAssistantResponse:
def list(self) -> ListAssistant:
pass

@abstractmethod
def get(self, assistant_id: str) -> AssistantResponse:
def get(self, assistant_id: str) -> Assistant:
pass

@abstractmethod
Expand All @@ -78,5 +73,5 @@ def modify(
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> AssistantResponse:
) -> Assistant:
pass
19 changes: 19 additions & 0 deletions ai21/clients/common/assistant/threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from ai21.models.assistant.message import Message
from ai21.models.responses.thread_response import Thread


class Threads(ABC):
_module_name = "threads"

@abstractmethod
def create(self, messages: List[Message], **kwargs) -> Thread:
pass

@abstractmethod
def get(self, thread_id: str) -> Thread:
pass
46 changes: 21 additions & 25 deletions ai21/clients/studio/resources/assistant/studio_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,17 @@

from typing import List

from ai21.clients.common.assistant.assistant import Assistant
from ai21.clients.common.assistant.assistants import Assistants
from ai21.clients.studio.resources.studio_resource import (
AsyncStudioResource,
StudioResource,
)
from ai21.models.responses.assistant_response import (
AssistantResponse,
Tool,
ToolResources,
ListAssistantResponse,
)
from ai21.models.assistant.assistant import Tool, ToolResources
from ai21.models.responses.assistant_response import Assistant, ListAssistant
from ai21.types import NotGiven, NOT_GIVEN


class StudioAssistant(StudioResource, Assistant):
class StudioAssistant(StudioResource, Assistants):
def create(
self,
name: str,
Expand All @@ -28,7 +24,7 @@ def create(
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> AssistantResponse:
) -> Assistant:
body = self._create_body(
name=name,
description=description,
Expand All @@ -40,13 +36,13 @@ def create(
**kwargs,
)

return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)
return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)

def get(self, assistant_id: str) -> AssistantResponse:
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)
def get(self, assistant_id: str) -> Assistant:
return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)

def list(self) -> ListAssistantResponse:
return self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse)
def list(self) -> ListAssistant:
return self._get(path=f"/{self._module_name}", response_cls=ListAssistant)

def modify(
self,
Expand All @@ -60,7 +56,7 @@ def modify(
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> AssistantResponse:
) -> Assistant:
body = self._create_body(
name=name,
description=description,
Expand All @@ -72,10 +68,10 @@ def modify(
tool_resources=tool_resources,
)

return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)
return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)


class AsyncStudioAssistant(AsyncStudioResource, Assistant):
class AsyncStudioAssistant(AsyncStudioResource, Assistants):
async def create(
self,
name: str,
Expand All @@ -87,7 +83,7 @@ async def create(
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
**kwargs,
) -> AssistantResponse:
) -> Assistant:
body = self._create_body(
name=name,
description=description,
Expand All @@ -99,13 +95,13 @@ async def create(
**kwargs,
)

return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)
return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant)

async def get(self, assistant_id: str) -> AssistantResponse:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)
async def get(self, assistant_id: str) -> Assistant:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant)

async def list(self) -> ListAssistantResponse:
return await self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse)
async def list(self) -> ListAssistant:
return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant)

async def modify(
self,
Expand All @@ -119,7 +115,7 @@ async def modify(
models: List[str] | NotGiven = NOT_GIVEN,
tools: List[Tool] | NotGiven = NOT_GIVEN,
tool_resources: ToolResources | NotGiven = NOT_GIVEN,
) -> AssistantResponse:
) -> Assistant:
body = self._create_body(
name=name,
description=description,
Expand All @@ -131,4 +127,4 @@ async def modify(
tool_resources=tool_resources,
)

return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse)
return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant)
28 changes: 28 additions & 0 deletions ai21/clients/studio/resources/assistant/studio_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import List

from ai21.clients.common.assistant.threads import Threads
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.assistant.message import Message
from ai21.models.responses.thread_response import Thread


class StudioThread(StudioResource, Threads):
def create(self, messages: List[Message], **kwargs) -> Thread:
body = dict(messages=messages)

return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)

def get(self, thread_id: str) -> Thread:
return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)


class AsyncStudioThread(AsyncStudioResource, Threads):
async def create(self, messages: List[Message], **kwargs) -> Thread:
body = dict(messages=messages)

return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread)

async def get(self, thread_id: str) -> Thread:
return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread)
4 changes: 3 additions & 1 deletion ai21/clients/studio/resources/beta/async_beta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant
from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread
from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
Expand All @@ -8,5 +9,6 @@ class AsyncBeta(AsyncStudioResource):
def __init__(self, client: AsyncAI21HTTPClient):
super().__init__(client)

self.conversational_rag = AsyncStudioConversationalRag(client)
self.assistants = AsyncStudioAssistant(client)
self.conversational_rag = AsyncStudioConversationalRag(client)
self.threads = AsyncStudioThread(client)
4 changes: 3 additions & 1 deletion ai21/clients/studio/resources/beta/beta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant
from ai21.clients.studio.resources.assistant.studio_thread import StudioThread
from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.http_client.http_client import AI21HTTPClient
Expand All @@ -8,5 +9,6 @@ class Beta(StudioResource):
def __init__(self, client: AI21HTTPClient):
super().__init__(client)

self.conversational_rag = StudioConversationalRag(client)
self.assistants = StudioAssistant(client)
self.conversational_rag = StudioConversationalRag(client)
self.threads = StudioThread(client)
Empty file.
12 changes: 12 additions & 0 deletions ai21/models/assistant/assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Optional, Literal

from typing_extensions import TypedDict

Optimization = Literal["cost", "latency"]
Tool = Literal["rag", "internet_research", "plan_approval"]


class ToolResources(TypedDict, total=False):
rag: Optional[dict]
internet_research: Optional[dict]
plan_approval: Optional[dict]
29 changes: 29 additions & 0 deletions ai21/models/assistant/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from datetime import datetime
from typing import Literal, Optional

from typing_extensions import TypedDict

from ai21.models.ai21_base_model import AI21BaseModel

ThreadMessageRole = Literal["assistant", "user"]


class MessageContentText(TypedDict):
type: Literal["text"]
text: str


class Message(TypedDict):
role: ThreadMessageRole
content: MessageContentText


class MessageResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["message"] = "message"
role: ThreadMessageRole
content: MessageContentText
run_id: Optional[str] = None
assistant_id: Optional[str] = None
21 changes: 5 additions & 16 deletions ai21/models/responses/assistant_response.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,15 @@
from datetime import datetime
from typing import Optional, List, Literal

from typing_extensions import TypedDict

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.assistant.assistant import ToolResources


Optimization = Literal["cost", "latency"]
Tool = Literal["rag", "internet_research", "plan_approval"]


class ToolResources(TypedDict, total=False):
rag: Optional[dict]
internet_research: Optional[dict]
plan_approval: Optional[dict]


class AssistantResponse(AI21BaseModel):
class Assistant(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: str
object: Literal["assistant"] = "assistant"
name: str
description: Optional[str] = None
optimization: str
Expand All @@ -33,5 +22,5 @@ class AssistantResponse(AI21BaseModel):
tool_resources: Optional[ToolResources] = None


class ListAssistantResponse(AI21BaseModel):
results: List[AssistantResponse]
class ListAssistant(AI21BaseModel):
results: List[Assistant]
15 changes: 15 additions & 0 deletions ai21/models/responses/thread_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datetime import datetime
from typing import List, Literal

from ai21.models.ai21_base_model import AI21BaseModel


class Thread(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["thread"] = "thread"


class ListThread(AI21BaseModel):
results: List[Thread]
Loading