Skip to content

Commit

Permalink
wip - chat stream
Browse files Browse the repository at this point in the history
  • Loading branch information
ion2088 committed Feb 28, 2024
1 parent 4cbaf32 commit e14d7bd
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 19 deletions.
1 change: 1 addition & 0 deletions src/firedust/_utils/types/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydantic import BaseModel, field_serializer

UNIX_TIMESTAMP = float # see: https://www.unixtimestamp.com/
STREAM_STOP_EVENT = "[[STOP]]"


class BaseConfig(BaseModel):
Expand Down
37 changes: 35 additions & 2 deletions src/firedust/_utils/types/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Literal
from typing import Any, List, Literal
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, field_serializer

from ._base import UNIX_TIMESTAMP

Expand All @@ -14,3 +15,35 @@ class APIContent(BaseModel):
timestamp: UNIX_TIMESTAMP
data: Any = {}
message: str | None


class MessageStreamEvent(BaseModel, frozen=True):
"""
Represents a message stream event model.
"""

assistant_id: UUID
user_id: UUID | None = None
timestamp: UNIX_TIMESTAMP
message: str
stream_ended: bool
memory_refs: List[UUID] = []
conversation_refs: List[UUID] = []

@field_serializer("assistant_id", when_used="always")
def serialize_assistant_id(self, value: UUID) -> str:
return str(value)

@field_serializer("user_id", when_used="always")
def serialize_user_id(self, value: UUID | None) -> str | None:
if value is None:
return None
return str(value)

@field_serializer("memory_refs", when_used="always")
def serialize_memory_refs(self, value: List[UUID]) -> List[str]:
return [str(x) for x in value]

@field_serializer("conversation_refs", when_used="always")
def serialize_conversation_refs(self, value: List[UUID]) -> List[str]:
return [str(x) for x in value]
36 changes: 23 additions & 13 deletions src/firedust/_utils/types/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List
from uuid import UUID

from pydantic import BaseModel, field_serializer
from pydantic import field_serializer, field_validator

from ._base import UNIX_TIMESTAMP, BaseConfig
from .ability import AbilityConfig
Expand Down Expand Up @@ -31,6 +31,22 @@ class AssistantConfig(BaseConfig, frozen=True):
abilities: List[AbilityConfig] = []
interfaces: Interfaces = Interfaces()

@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str | Exception:
if len(name) > 50:
raise ValueError("Assistant name exceeds maximum length of 50 characters")
if len(name) < 1:
raise ValueError("Assistant name must be at least 1 character")
return name

@field_validator("instructions")
@classmethod
def validate_instructions(cls, instructions: str) -> str | Exception:
if len(instructions) < 20:
raise ValueError("Assistant instructions must be at least 20 characters")
return instructions


class UserMessage(BaseConfig, frozen=True):
"""
Expand Down Expand Up @@ -69,17 +85,17 @@ class AssistantMessage(BaseConfig, frozen=True):
response_to_id (UUID): The unique identifier of the message to which the assistant is responding.
message (str): The text of the message.
timestamp (UNIX_TIMESTAMP): The time when the message was sent.
context (str): The context of the message.
memory_refs (List[UUID], optional): The memory references of the message. Defaults to None.
conversation_refs (List[UUID], optional): The conversation references of the message. Defaults to None.
memory_refs (List[UUID], optional): The unique identifiers of the memories referenced by the message. Defaults to [].
conversation_refs (List[UUID], optional): The unique identifiers of the conversations referenced by the message. Defaults to [].
"""

assistant_id: UUID
user_id: UUID | None = None
response_to_id: UUID
timestamp: UNIX_TIMESTAMP
message: str
context: "Context"
memory_refs: List[UUID] = []
conversation_refs: List[UUID] = []

@field_serializer("assistant_id", when_used="always")
def serialize_assistant_id(self, value: UUID) -> str:
Expand All @@ -95,16 +111,10 @@ def serialize_user_id(self, value: UUID | None) -> str | None:
def serialize_response_to_id(self, value: UUID) -> str:
return str(value)


class Context(BaseModel):
instructions: str
memory_refs: List[UUID]
conversation_refs: List[UUID]

@field_serializer("memory_refs", when_used="always")
def serialize_memory_refs(self, value: List[UUID]) -> List[str]:
return [str(ref) for ref in value]
return [str(x) for x in value]

@field_serializer("conversation_refs", when_used="always")
def serialize_conversation_refs(self, value: List[UUID]) -> List[str]:
return [str(ref) for ref in value]
return [str(x) for x in value]
10 changes: 7 additions & 3 deletions src/firedust/interface/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
print(response)
"""

import json
from datetime import datetime
from typing import Iterator
from uuid import UUID, uuid4

from firedust._utils.api import APIClient
from firedust._utils.errors import APIError
from firedust._utils.types.api import MessageStreamEvent
from firedust._utils.types.assistant import AssistantConfig, UserMessage


Expand All @@ -43,7 +45,9 @@ def __init__(self, config: AssistantConfig, api_client: APIClient) -> None:
self.config = config
self.api_client = api_client

def stream(self, message: str, user_id: UUID | None = None) -> Iterator[bytes]:
def stream(
self, message: str, user_id: UUID | None = None
) -> Iterator[MessageStreamEvent]:
"""
Streams a conversation with the assistant.
Add a user id to keep chat histories separate for different users.
Expand All @@ -53,7 +57,7 @@ def stream(self, message: str, user_id: UUID | None = None) -> Iterator[bytes]:
user_id (UUID, optional): The unique identifier of the user. Defaults to None.
Yields:
Iterator[bytes]: The response from the assistant.
Iterator[MessageStreamEvent]: The response from the assistant.
"""
user_message = UserMessage(
id=uuid4(),
Expand All @@ -68,7 +72,7 @@ def stream(self, message: str, user_id: UUID | None = None) -> Iterator[bytes]:
"/chat/stream",
data=user_message.model_dump(),
):
yield msg
yield MessageStreamEvent(**json.loads(msg.decode("utf-8")))
except Exception as e:
raise APIError(f"Failed to stream the conversation: {e}")

Expand Down
6 changes: 5 additions & 1 deletion tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@
def test_chat_streaming() -> None:
assistant = Assistant.create()
response = assistant.chat.stream("Hi, how are you?")
x = 10

for x in response:
print(x.message)

w = 10

0 comments on commit e14d7bd

Please sign in to comment.