From 62566cd1fd0b42bfcd181ebb5b5032e62a15730b Mon Sep 17 00:00:00 2001 From: benshuk Date: Sun, 1 Dec 2024 13:06:54 +0200 Subject: [PATCH 1/3] feat: :sparkles: add support for Thread resource --- ai21/clients/common/assistant/thread.py | 22 ++++++++++ .../resources/assistant/studio_assistant.py | 2 +- .../resources/assistant/studio_thread.py | 35 ++++++++++++++++ .../studio/resources/beta/async_beta.py | 4 +- ai21/clients/studio/resources/beta/beta.py | 4 +- ai21/models/responses/assistant_response.py | 2 +- ai21/models/responses/thread_response.py | 41 +++++++++++++++++++ 7 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 ai21/clients/common/assistant/thread.py create mode 100644 ai21/clients/studio/resources/assistant/studio_thread.py create mode 100644 ai21/models/responses/thread_response.py diff --git a/ai21/clients/common/assistant/thread.py b/ai21/clients/common/assistant/thread.py new file mode 100644 index 00000000..1a074db2 --- /dev/null +++ b/ai21/clients/common/assistant/thread.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from ai21.models.responses.thread_response import ThreadResponse, CreateMessagePayload + + +class Thread(ABC): + _module_name = "threads" + + @abstractmethod + def create( + self, + messages: List[CreateMessagePayload], + **kwargs, + ) -> ThreadResponse: + pass + + @abstractmethod + def get(self, thread_id: str) -> ThreadResponse: + pass diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index 82014880..2f1a67eb 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -99,7 +99,7 @@ async def create( **kwargs, ) - return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) async def get(self, assistant_id: str) -> AssistantResponse: return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py new file mode 100644 index 00000000..71e54a25 --- /dev/null +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import List + +from ai21.clients.common.assistant.thread import Thread +from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource +from ai21.models.responses.thread_response import CreateMessagePayload, ThreadResponse + + +class StudioThread(StudioResource, Thread): + def create( + self, + messages: List[CreateMessagePayload], + **kwargs, + ) -> ThreadResponse: + body = dict(messages=messages) + + return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) + + def get(self, thread_id: str) -> ThreadResponse: + return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) + + +class AsyncStudioThread(AsyncStudioResource, Thread): + async def create( + self, + messages: List[CreateMessagePayload], + **kwargs, + ) -> ThreadResponse: + body = dict(messages=messages) + + return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) + + async def get(self, thread_id: str) -> ThreadResponse: + return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) diff --git a/ai21/clients/studio/resources/beta/async_beta.py b/ai21/clients/studio/resources/beta/async_beta.py index 521c7a13..1bb13bbe 100644 --- a/ai21/clients/studio/resources/beta/async_beta.py +++ b/ai21/clients/studio/resources/beta/async_beta.py @@ -1,4 +1,5 @@ from ai21.clients.studio.resources.assistant.studio_assistant import AsyncStudioAssistant +from ai21.clients.studio.resources.assistant.studio_thread import AsyncStudioThread from ai21.clients.studio.resources.studio_conversational_rag import AsyncStudioConversationalRag from ai21.clients.studio.resources.studio_resource import AsyncStudioResource from ai21.http_client.async_http_client import AsyncAI21HTTPClient @@ -8,5 +9,6 @@ class AsyncBeta(AsyncStudioResource): def __init__(self, client: AsyncAI21HTTPClient): super().__init__(client) - self.conversational_rag = AsyncStudioConversationalRag(client) self.assistants = AsyncStudioAssistant(client) + self.conversational_rag = AsyncStudioConversationalRag(client) + self.threads = AsyncStudioThread(client) diff --git a/ai21/clients/studio/resources/beta/beta.py b/ai21/clients/studio/resources/beta/beta.py index 8560597a..affede10 100644 --- a/ai21/clients/studio/resources/beta/beta.py +++ b/ai21/clients/studio/resources/beta/beta.py @@ -1,4 +1,5 @@ from ai21.clients.studio.resources.assistant.studio_assistant import StudioAssistant +from ai21.clients.studio.resources.assistant.studio_thread import StudioThread from ai21.clients.studio.resources.studio_conversational_rag import StudioConversationalRag from ai21.clients.studio.resources.studio_resource import StudioResource from ai21.http_client.http_client import AI21HTTPClient @@ -8,5 +9,6 @@ class Beta(StudioResource): def __init__(self, client: AI21HTTPClient): super().__init__(client) - self.conversational_rag = StudioConversationalRag(client) self.assistants = StudioAssistant(client) + self.conversational_rag = StudioConversationalRag(client) + self.threads = StudioThread(client) diff --git a/ai21/models/responses/assistant_response.py b/ai21/models/responses/assistant_response.py index d7b0f186..000b0346 100644 --- a/ai21/models/responses/assistant_response.py +++ b/ai21/models/responses/assistant_response.py @@ -20,7 +20,7 @@ class AssistantResponse(AI21BaseModel): id: str created_at: datetime updated_at: datetime - object: str + object: Literal["assistant"] = "assistant" name: str description: Optional[str] = None optimization: str diff --git a/ai21/models/responses/thread_response.py b/ai21/models/responses/thread_response.py new file mode 100644 index 00000000..662047bb --- /dev/null +++ b/ai21/models/responses/thread_response.py @@ -0,0 +1,41 @@ +from datetime import datetime +from typing import Optional, List, Literal + +from typing_extensions import TypedDict + +from ai21.models.ai21_base_model import AI21BaseModel + + +MessageRole = Literal["assistant", "user"] + + +class MessageContentText(TypedDict): + type: Literal["text"] + text: str + + +class CreateMessagePayload(TypedDict): + role: MessageRole + content: MessageContentText + + +class ThreadMessageResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["message"] = "message" + role: MessageRole + content: MessageContentText + run_id: Optional[str] = None + assistant_id: Optional[str] = None + + +class ThreadResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["thread"] = "thread" + + +class ListThreadResponse(AI21BaseModel): + results: List[ThreadResponse] From ae7a3e40dd13c3feb0b41af8bbef3ed4b2490ca9 Mon Sep 17 00:00:00 2001 From: benshuk Date: Sun, 1 Dec 2024 15:03:19 +0200 Subject: [PATCH 2/3] fix: :truck: rename classes and such --- .../assistant/{assistant.py => assistants.py} | 17 ++++---- ai21/clients/common/assistant/thread.py | 22 ---------- ai21/clients/common/assistant/threads.py | 23 ++++++++++ .../resources/assistant/studio_assistant.py | 42 +++++++++---------- .../resources/assistant/studio_thread.py | 28 ++++++------- ai21/models/assistant/__init__.py | 0 ai21/models/assistant/assistant.py | 12 ++++++ ai21/models/assistant/thread_message.py | 15 +++++++ ai21/models/responses/assistant_response.py | 19 ++------- ai21/models/responses/thread_response.py | 28 ++++--------- 10 files changed, 104 insertions(+), 102 deletions(-) rename ai21/clients/common/assistant/{assistant.py => assistants.py} (88%) delete mode 100644 ai21/clients/common/assistant/thread.py create mode 100644 ai21/clients/common/assistant/threads.py create mode 100644 ai21/models/assistant/__init__.py create mode 100644 ai21/models/assistant/assistant.py create mode 100644 ai21/models/assistant/thread_message.py diff --git a/ai21/clients/common/assistant/assistant.py b/ai21/clients/common/assistant/assistants.py similarity index 88% rename from ai21/clients/common/assistant/assistant.py rename to ai21/clients/common/assistant/assistants.py index c2ad588a..0fd08d7b 100644 --- a/ai21/clients/common/assistant/assistant.py +++ b/ai21/clients/common/assistant/assistants.py @@ -3,18 +3,17 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List +from ai21.models.assistant.assistant import Optimization, Tool from ai21.models.responses.assistant_response import ( - AssistantResponse, - Optimization, + Assistant, ToolResources, - Tool, - ListAssistantResponse, + ListAssistant, ) from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given -class Assistant(ABC): +class Assistants(ABC): _module_name = "assistants" @abstractmethod @@ -29,7 +28,7 @@ def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: pass def _create_body( @@ -58,11 +57,11 @@ def _create_body( ) @abstractmethod - def list(self) -> ListAssistantResponse: + def list(self) -> ListAssistant: pass @abstractmethod - def get(self, assistant_id: str) -> AssistantResponse: + def get(self, assistant_id: str) -> Assistant: pass @abstractmethod @@ -78,5 +77,5 @@ def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: pass diff --git a/ai21/clients/common/assistant/thread.py b/ai21/clients/common/assistant/thread.py deleted file mode 100644 index 1a074db2..00000000 --- a/ai21/clients/common/assistant/thread.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import List - -from ai21.models.responses.thread_response import ThreadResponse, CreateMessagePayload - - -class Thread(ABC): - _module_name = "threads" - - @abstractmethod - def create( - self, - messages: List[CreateMessagePayload], - **kwargs, - ) -> ThreadResponse: - pass - - @abstractmethod - def get(self, thread_id: str) -> ThreadResponse: - pass diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py new file mode 100644 index 00000000..f68bc82b --- /dev/null +++ b/ai21/clients/common/assistant/threads.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List + +from ai21.models.assistant.thread_message import CreateThreadMessagePayload +from ai21.models.responses.thread_response import Thread + + +class Threads(ABC): + _module_name = "threads" + + @abstractmethod + def create( + self, + messages: List[CreateThreadMessagePayload], + **kwargs, + ) -> Thread: + pass + + @abstractmethod + def get(self, thread_id: str) -> Thread: + pass diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index 2f1a67eb..ec14758a 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -2,21 +2,21 @@ from typing import List -from ai21.clients.common.assistant.assistant import Assistant +from ai21.clients.common.assistant.assistants import Assistants from ai21.clients.studio.resources.studio_resource import ( AsyncStudioResource, StudioResource, ) from ai21.models.responses.assistant_response import ( - AssistantResponse, + Assistant, Tool, ToolResources, - ListAssistantResponse, + ListAssistant, ) from ai21.types import NotGiven, NOT_GIVEN -class StudioAssistant(StudioResource, Assistant): +class StudioAssistant(StudioResource, Assistants): def create( self, name: str, @@ -28,7 +28,7 @@ def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -40,13 +40,13 @@ def create( **kwargs, ) - return self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - def get(self, assistant_id: str) -> AssistantResponse: - return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) + def get(self, assistant_id: str) -> Assistant: + return self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) - def list(self) -> ListAssistantResponse: - return self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse) + def list(self) -> ListAssistant: + return self._get(path=f"/{self._module_name}", response_cls=ListAssistant) def modify( self, @@ -60,7 +60,7 @@ def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -72,10 +72,10 @@ def modify( tool_resources=tool_resources, ) - return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) + return self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) -class AsyncStudioAssistant(AsyncStudioResource, Assistant): +class AsyncStudioAssistant(AsyncStudioResource, Assistants): async def create( self, name: str, @@ -87,7 +87,7 @@ async def create( tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, **kwargs, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -99,13 +99,13 @@ async def create( **kwargs, ) - return await self._post(path=f"/{self._module_name}", body=body, response_cls=AssistantResponse) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=Assistant) - async def get(self, assistant_id: str) -> AssistantResponse: - return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=AssistantResponse) + async def get(self, assistant_id: str) -> Assistant: + return await self._get(path=f"/{self._module_name}/{assistant_id}", response_cls=Assistant) - async def list(self) -> ListAssistantResponse: - return await self._get(path=f"/{self._module_name}", response_cls=ListAssistantResponse) + async def list(self) -> ListAssistant: + return await self._get(path=f"/{self._module_name}", response_cls=ListAssistant) async def modify( self, @@ -119,7 +119,7 @@ async def modify( models: List[str] | NotGiven = NOT_GIVEN, tools: List[Tool] | NotGiven = NOT_GIVEN, tool_resources: ToolResources | NotGiven = NOT_GIVEN, - ) -> AssistantResponse: + ) -> Assistant: body = self._create_body( name=name, description=description, @@ -131,4 +131,4 @@ async def modify( tool_resources=tool_resources, ) - return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=AssistantResponse) + return await self._patch(path=f"/{self._module_name}/{assistant_id}", body=body, response_cls=Assistant) diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index 71e54a25..c8aac8cf 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -2,34 +2,34 @@ from typing import List -from ai21.clients.common.assistant.thread import Thread +from ai21.clients.common.assistant.threads import Threads from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models.responses.thread_response import CreateMessagePayload, ThreadResponse +from ai21.models.responses.thread_response import CreateThreadMessagePayload, Thread -class StudioThread(StudioResource, Thread): +class StudioThread(StudioResource, Threads): def create( self, - messages: List[CreateMessagePayload], + messages: List[CreateThreadMessagePayload], **kwargs, - ) -> ThreadResponse: + ) -> Thread: body = dict(messages=messages) - return self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - def get(self, thread_id: str) -> ThreadResponse: - return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) + def get(self, thread_id: str) -> Thread: + return self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) -class AsyncStudioThread(AsyncStudioResource, Thread): +class AsyncStudioThread(AsyncStudioResource, Threads): async def create( self, - messages: List[CreateMessagePayload], + messages: List[CreateThreadMessagePayload], **kwargs, - ) -> ThreadResponse: + ) -> Thread: body = dict(messages=messages) - return await self._post(path=f"/{self._module_name}", body=body, response_cls=ThreadResponse) + return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) - async def get(self, thread_id: str) -> ThreadResponse: - return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=ThreadResponse) + async def get(self, thread_id: str) -> Thread: + return await self._get(path=f"/{self._module_name}/{thread_id}", response_cls=Thread) diff --git a/ai21/models/assistant/__init__.py b/ai21/models/assistant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ai21/models/assistant/assistant.py b/ai21/models/assistant/assistant.py new file mode 100644 index 00000000..44c3ed06 --- /dev/null +++ b/ai21/models/assistant/assistant.py @@ -0,0 +1,12 @@ +from typing import Optional, Literal + +from typing_extensions import TypedDict + +Optimization = Literal["cost", "latency"] +Tool = Literal["rag", "internet_research", "plan_approval"] + + +class ToolResources(TypedDict, total=False): + rag: Optional[dict] + internet_research: Optional[dict] + plan_approval: Optional[dict] diff --git a/ai21/models/assistant/thread_message.py b/ai21/models/assistant/thread_message.py new file mode 100644 index 00000000..87ef50e3 --- /dev/null +++ b/ai21/models/assistant/thread_message.py @@ -0,0 +1,15 @@ +from typing import Literal + +from typing_extensions import TypedDict + +ThreadMessageRole = Literal["assistant", "user"] + + +class ThreadMessageContentText(TypedDict): + type: Literal["text"] + text: str + + +class CreateThreadMessagePayload(TypedDict): + role: ThreadMessageRole + content: ThreadMessageContentText diff --git a/ai21/models/responses/assistant_response.py b/ai21/models/responses/assistant_response.py index 000b0346..3263b5a8 100644 --- a/ai21/models/responses/assistant_response.py +++ b/ai21/models/responses/assistant_response.py @@ -1,22 +1,11 @@ from datetime import datetime from typing import Optional, List, Literal -from typing_extensions import TypedDict - from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.assistant import ToolResources -Optimization = Literal["cost", "latency"] -Tool = Literal["rag", "internet_research", "plan_approval"] - - -class ToolResources(TypedDict, total=False): - rag: Optional[dict] - internet_research: Optional[dict] - plan_approval: Optional[dict] - - -class AssistantResponse(AI21BaseModel): +class Assistant(AI21BaseModel): id: str created_at: datetime updated_at: datetime @@ -33,5 +22,5 @@ class AssistantResponse(AI21BaseModel): tool_resources: Optional[ToolResources] = None -class ListAssistantResponse(AI21BaseModel): - results: List[AssistantResponse] +class ListAssistant(AI21BaseModel): + results: List[Assistant] diff --git a/ai21/models/responses/thread_response.py b/ai21/models/responses/thread_response.py index 662047bb..6cb06443 100644 --- a/ai21/models/responses/thread_response.py +++ b/ai21/models/responses/thread_response.py @@ -1,41 +1,27 @@ from datetime import datetime from typing import Optional, List, Literal -from typing_extensions import TypedDict - from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.assistant.thread_message import ThreadMessageRole, ThreadMessageContentText -MessageRole = Literal["assistant", "user"] - - -class MessageContentText(TypedDict): - type: Literal["text"] - text: str - - -class CreateMessagePayload(TypedDict): - role: MessageRole - content: MessageContentText - - -class ThreadMessageResponse(AI21BaseModel): +class ThreadMessage(AI21BaseModel): id: str created_at: datetime updated_at: datetime object: Literal["message"] = "message" - role: MessageRole - content: MessageContentText + role: ThreadMessageRole + content: ThreadMessageContentText run_id: Optional[str] = None assistant_id: Optional[str] = None -class ThreadResponse(AI21BaseModel): +class Thread(AI21BaseModel): id: str created_at: datetime updated_at: datetime object: Literal["thread"] = "thread" -class ListThreadResponse(AI21BaseModel): - results: List[ThreadResponse] +class ListThread(AI21BaseModel): + results: List[Thread] From c944172571e10452b93542944794e6bcaeaff2c7 Mon Sep 17 00:00:00 2001 From: benshuk Date: Sun, 1 Dec 2024 16:04:01 +0200 Subject: [PATCH 3/3] fix: :truck: move classes and such --- ai21/clients/common/assistant/assistants.py | 8 ++--- ai21/clients/common/assistant/threads.py | 8 ++--- .../resources/assistant/studio_assistant.py | 8 ++--- .../resources/assistant/studio_thread.py | 15 +++------- ai21/models/assistant/message.py | 29 +++++++++++++++++++ ai21/models/assistant/thread_message.py | 15 ---------- ai21/models/responses/thread_response.py | 14 +-------- 7 files changed, 40 insertions(+), 57 deletions(-) create mode 100644 ai21/models/assistant/message.py delete mode 100644 ai21/models/assistant/thread_message.py diff --git a/ai21/clients/common/assistant/assistants.py b/ai21/clients/common/assistant/assistants.py index 0fd08d7b..e773f89a 100644 --- a/ai21/clients/common/assistant/assistants.py +++ b/ai21/clients/common/assistant/assistants.py @@ -3,12 +3,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List -from ai21.models.assistant.assistant import Optimization, Tool -from ai21.models.responses.assistant_response import ( - Assistant, - ToolResources, - ListAssistant, -) +from ai21.models.assistant.assistant import Optimization, Tool, ToolResources +from ai21.models.responses.assistant_response import Assistant, ListAssistant from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given diff --git a/ai21/clients/common/assistant/threads.py b/ai21/clients/common/assistant/threads.py index f68bc82b..5025ccc2 100644 --- a/ai21/clients/common/assistant/threads.py +++ b/ai21/clients/common/assistant/threads.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import List -from ai21.models.assistant.thread_message import CreateThreadMessagePayload +from ai21.models.assistant.message import Message from ai21.models.responses.thread_response import Thread @@ -11,11 +11,7 @@ class Threads(ABC): _module_name = "threads" @abstractmethod - def create( - self, - messages: List[CreateThreadMessagePayload], - **kwargs, - ) -> Thread: + def create(self, messages: List[Message], **kwargs) -> Thread: pass @abstractmethod diff --git a/ai21/clients/studio/resources/assistant/studio_assistant.py b/ai21/clients/studio/resources/assistant/studio_assistant.py index ec14758a..fa008695 100644 --- a/ai21/clients/studio/resources/assistant/studio_assistant.py +++ b/ai21/clients/studio/resources/assistant/studio_assistant.py @@ -7,12 +7,8 @@ AsyncStudioResource, StudioResource, ) -from ai21.models.responses.assistant_response import ( - Assistant, - Tool, - ToolResources, - ListAssistant, -) +from ai21.models.assistant.assistant import Tool, ToolResources +from ai21.models.responses.assistant_response import Assistant, ListAssistant from ai21.types import NotGiven, NOT_GIVEN diff --git a/ai21/clients/studio/resources/assistant/studio_thread.py b/ai21/clients/studio/resources/assistant/studio_thread.py index c8aac8cf..fd48563e 100644 --- a/ai21/clients/studio/resources/assistant/studio_thread.py +++ b/ai21/clients/studio/resources/assistant/studio_thread.py @@ -4,15 +4,12 @@ from ai21.clients.common.assistant.threads import Threads from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource -from ai21.models.responses.thread_response import CreateThreadMessagePayload, Thread +from ai21.models.assistant.message import Message +from ai21.models.responses.thread_response import Thread class StudioThread(StudioResource, Threads): - def create( - self, - messages: List[CreateThreadMessagePayload], - **kwargs, - ) -> Thread: + def create(self, messages: List[Message], **kwargs) -> Thread: body = dict(messages=messages) return self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) @@ -22,11 +19,7 @@ def get(self, thread_id: str) -> Thread: class AsyncStudioThread(AsyncStudioResource, Threads): - async def create( - self, - messages: List[CreateThreadMessagePayload], - **kwargs, - ) -> Thread: + async def create(self, messages: List[Message], **kwargs) -> Thread: body = dict(messages=messages) return await self._post(path=f"/{self._module_name}", body=body, response_cls=Thread) diff --git a/ai21/models/assistant/message.py b/ai21/models/assistant/message.py new file mode 100644 index 00000000..f3ec7d8d --- /dev/null +++ b/ai21/models/assistant/message.py @@ -0,0 +1,29 @@ +from datetime import datetime +from typing import Literal, Optional + +from typing_extensions import TypedDict + +from ai21.models.ai21_base_model import AI21BaseModel + +ThreadMessageRole = Literal["assistant", "user"] + + +class MessageContentText(TypedDict): + type: Literal["text"] + text: str + + +class Message(TypedDict): + role: ThreadMessageRole + content: MessageContentText + + +class MessageResponse(AI21BaseModel): + id: str + created_at: datetime + updated_at: datetime + object: Literal["message"] = "message" + role: ThreadMessageRole + content: MessageContentText + run_id: Optional[str] = None + assistant_id: Optional[str] = None diff --git a/ai21/models/assistant/thread_message.py b/ai21/models/assistant/thread_message.py deleted file mode 100644 index 87ef50e3..00000000 --- a/ai21/models/assistant/thread_message.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Literal - -from typing_extensions import TypedDict - -ThreadMessageRole = Literal["assistant", "user"] - - -class ThreadMessageContentText(TypedDict): - type: Literal["text"] - text: str - - -class CreateThreadMessagePayload(TypedDict): - role: ThreadMessageRole - content: ThreadMessageContentText diff --git a/ai21/models/responses/thread_response.py b/ai21/models/responses/thread_response.py index 6cb06443..b2c9bbc7 100644 --- a/ai21/models/responses/thread_response.py +++ b/ai21/models/responses/thread_response.py @@ -1,19 +1,7 @@ from datetime import datetime -from typing import Optional, List, Literal +from typing import List, Literal from ai21.models.ai21_base_model import AI21BaseModel -from ai21.models.assistant.thread_message import ThreadMessageRole, ThreadMessageContentText - - -class ThreadMessage(AI21BaseModel): - id: str - created_at: datetime - updated_at: datetime - object: Literal["message"] = "message" - role: ThreadMessageRole - content: ThreadMessageContentText - run_id: Optional[str] = None - assistant_id: Optional[str] = None class Thread(AI21BaseModel):