Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ add support for Thread resource #232

Merged
merged 3 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions ai21/clients/common/assistant/thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List

from ai21.models.responses.thread_response import ThreadResponse, CreateMessagePayload


class Thread(ABC):
Josephasafg marked this conversation as resolved.
Show resolved Hide resolved
_module_name = "threads"

@abstractmethod
def create(
self,
messages: List[CreateMessagePayload],
**kwargs,
) -> ThreadResponse:
Josephasafg marked this conversation as resolved.
Show resolved Hide resolved
pass

Josephasafg marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
def get(self, thread_id: str) -> ThreadResponse:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def create(
**kwargs,
)

return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)
return await self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse)

async def get(self, assistant_id: str) -> AssistantResponse:
return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse)
Expand Down
35 changes: 35 additions & 0 deletions ai21/clients/studio/resources/assistant/studio_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from typing import List

from ai21.clients.common.assistant.thread import Thread
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
from ai21.models.responses.thread_response import CreateMessagePayload, ThreadResponse


class StudioThread(StudioResource, Thread):
def create(
self,
messages: List[CreateMessagePayload],
**kwargs,
) -> ThreadResponse:
body = dict(messages=messages)

return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)

def get(self, thread_id: str) -> ThreadResponse:
return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse)


class AsyncStudioThread(AsyncStudioResource, Thread):
async def create(
self,
messages: List[CreateMessagePayload],
**kwargs,
) -> ThreadResponse:
body = dict(messages=messages)

return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse)

async def get(self, thread_id: str) -> ThreadResponse:
return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse)
4 changes: 3 additions & 1 deletion ai21/clients/studio/resources/beta/async_beta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant
from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread
from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.http_client.async_http_client import AsyncAI21HTTPClient
Expand All @@ -8,5 +9,6 @@ class AsyncBeta(AsyncStudioResource):
def __init__(self, client: AsyncAI21HTTPClient):
super().__init__(client)

self.conversational_rag = AsyncStudioConversationalRag(client)
self.assistants = AsyncStudioAssistant(client)
self.conversational_rag = AsyncStudioConversationalRag(client)
self.threads = AsyncStudioThread(client)
4 changes: 3 additions & 1 deletion ai21/clients/studio/resources/beta/beta.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant
from ai21.clients.studio.resources.assistant.studio_thread import StudioThread
from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.http_client.http_client import AI21HTTPClient
Expand All @@ -8,5 +9,6 @@ class Beta(StudioResource):
def __init__(self, client: AI21HTTPClient):
super().__init__(client)

self.conversational_rag = StudioConversationalRag(client)
self.assistants = StudioAssistant(client)
self.conversational_rag = StudioConversationalRag(client)
self.threads = StudioThread(client)
2 changes: 1 addition & 1 deletion ai21/models/responses/assistant_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AssistantResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: str
object: Literal["assistant"] = "assistant"
name: str
description: Optional[str] = None
optimization: str
Expand Down
41 changes: 41 additions & 0 deletions ai21/models/responses/thread_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import datetime
from typing import Optional, List, Literal

from typing_extensions import TypedDict

from ai21.models.ai21_base_model import AI21BaseModel


MessageRole = Literal["assistant", "user"]


class MessageContentText(TypedDict):
type: Literal["text"]
text: str


class CreateMessagePayload(TypedDict):
Josephasafg marked this conversation as resolved.
Show resolved Hide resolved
role: MessageRole
content: MessageContentText

Josephasafg marked this conversation as resolved.
Show resolved Hide resolved

class ThreadMessageResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["message"] = "message"
role: MessageRole
content: MessageContentText
run_id: Optional[str] = None
assistant_id: Optional[str] = None


class ThreadResponse(AI21BaseModel):
id: str
created_at: datetime
updated_at: datetime
object: Literal["thread"] = "thread"


class ListThreadResponse(AI21BaseModel):
results: List[ThreadResponse]
Loading