From dd05c83ce79d01a0c1296807cbe5e2b85deadfd0 Mon Sep 17 00:00:00 2001 From: Ion Date: Thu, 12 Sep 2024 10:59:35 +0000 Subject: [PATCH] add gpt4o + updates --- poetry.lock | 17 +++- pyproject.toml | 3 +- src/firedust/entrypoint/assistant.py | 20 ++-- src/firedust/types/chat.py | 7 +- src/firedust/utils/api.py | 140 ++++++++++++--------------- tests/assistant/test_chat.py | 57 ++++++----- 6 files changed, 126 insertions(+), 118 deletions(-) diff --git a/poetry.lock b/poetry.lock index 401f573..d9ea4aa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -532,6 +532,21 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tomli" version = "2.0.1" @@ -557,4 +572,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.8, <3.13" -content-hash = "41f23079f7ee2a01e32f4d7f4ad68605a5464cdb3c091b198b9b3a9b19cb89e0" +content-hash = "76b9158cdbceb545bfef6b0d3ef8ac988153a1bd993d189d18e1ef9fa32c17d6" diff --git a/pyproject.toml b/pyproject.toml index 4e73d9a..f931edf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "firedust" -version = "0.0.50" +version = "0.0.51" description = "A simple framework to build and deploy AI tools." authors = ["ion "] readme = "README.md" @@ -9,6 +9,7 @@ readme = "README.md" python = "^3.8, <3.13" pydantic = "^2.5.3" httpx = "^0.26.0" +tenacity = "^9.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.3.1" diff --git a/src/firedust/entrypoint/assistant.py b/src/firedust/entrypoint/assistant.py index 284c3f9..b3c0819 100644 --- a/src/firedust/entrypoint/assistant.py +++ b/src/firedust/entrypoint/assistant.py @@ -42,7 +42,7 @@ from firedust.types import APIContent, Assistant, AssistantConfig, AsyncAssistant from firedust.types.base import INFERENCE_MODEL -from firedust.utils.api import AsyncAPIClient, SyncAPIClient +from firedust.utils.api import SyncAPIClient, AsyncAPIClient from firedust.utils.errors import APIError from firedust.utils.logging import LOG @@ -50,15 +50,15 @@ def create( name: str, instructions: str = "", - model: INFERENCE_MODEL = "openai/gpt-4", + model: INFERENCE_MODEL = "openai/gpt-4o", ) -> Assistant: """ Creates a new assistant with the specified configuration. - Arg + Args: name (str): The name of the assistant. instructions (str): The instructions for the assistant. - model (INFERENCE_MODEL, optional): The inference model to use. Defaults to "mistral/mistral-medium". + model (INFERENCE_MODEL, optional): The inference model to use. Defaults to "openai/gpt-4". Returns: Assistant: A new instance of the assistant class. @@ -68,6 +68,7 @@ def create( response = api_client.post("/assistant", data=config.model_dump()) if not response.is_success: + api_client.close() raise APIError( code=response.status_code, message=f"Failed to create an assistant with config {config}: {response.text}", @@ -75,14 +76,13 @@ def create( LOG.info( f"Assistant {config.name} was created successfully and saved to the cloud." ) - return Assistant._create_instance(config, api_client) async def async_create( name: str, instructions: str = "", - model: INFERENCE_MODEL = "openai/gpt-4", + model: INFERENCE_MODEL = "openai/gpt-4o", ) -> AsyncAssistant: """ Asynchronously creates a new assistant with the specified configuration. @@ -90,7 +90,7 @@ async def async_create( Args: name (str): The name of the assistant. instructions (str): The instructions for the assistant. - model (INFERENCE_MODEL, optional): The inference model to use. Defaults to "mistral/mistral-medium". + model (INFERENCE_MODEL, optional): The inference model to use. Defaults to "openai/gpt-4". Returns: AsyncAssistant: A new instance of the AsyncAssistant class. @@ -100,6 +100,7 @@ async def async_create( response = await api_client.post("/assistant", data=config.model_dump()) if not response.is_success: + await api_client.close() raise APIError( code=response.status_code, message=f"Failed to create an assistant with config {config}: {response.text}", @@ -107,7 +108,6 @@ async def async_create( LOG.info( f"Assistant {config.name} was created successfully and saved to the cloud." ) - return await AsyncAssistant._create_instance(config, api_client) @@ -124,6 +124,7 @@ def load(name: str) -> Assistant: api_client = SyncAPIClient() response = api_client.get("/assistant", params={"name": name}) if not response.is_success: + api_client.close() raise APIError( code=response.status_code, message=f"Failed to load assistant {name}: {response.text}", @@ -146,6 +147,7 @@ async def async_load(name: str) -> AsyncAssistant: api_client = AsyncAPIClient() response = await api_client.get("/assistant", params={"name": name}) if not response.is_success: + await api_client.close() raise APIError( code=response.status_code, message=f"Failed to load the assistant with id {name}: {response.text}", @@ -165,6 +167,7 @@ def list() -> List[Assistant]: api_client = SyncAPIClient() response = api_client.get("/assistant/list") if not response.is_success: + api_client.close() raise APIError( code=response.status_code, message=f"Failed to list the assistants: {response.text}", @@ -184,6 +187,7 @@ async def async_list() -> List[AsyncAssistant]: api_client = AsyncAPIClient() response = await api_client.get("/assistant/list") if not response.is_success: + await api_client.close() raise APIError( code=response.status_code, message=f"Failed to list the assistants: {response.text}", diff --git a/src/firedust/types/chat.py b/src/firedust/types/chat.py index 5d849c7..e49f38d 100644 --- a/src/firedust/types/chat.py +++ b/src/firedust/types/chat.py @@ -30,7 +30,7 @@ class Message(BaseConfig, frozen=True): class UserMessage(Message): - author: Literal["user"] = Field(default="user", const=True) + author: Literal["user"] = Field(default="user") class StructuredUserMessage(UserMessage): @@ -38,7 +38,7 @@ class StructuredUserMessage(UserMessage): class AssistantMessage(Message): - author: Literal["assistant"] = Field(default="assistant", const=True) + author: Literal["assistant"] = Field(default="assistant") class MessageReferences(BaseModel): @@ -63,12 +63,11 @@ def serialize_conversation_refs(self, value: Sequence[UUID]) -> Sequence[str]: class StructuredAssistantMessage(BaseConfig, frozen=True): - # Unable to build of Assistant message because message is redefined assistant: str = Field(...) user: str = Field(...) timestamp: UNIX_TIMESTAMP = Field(...) message: STRUCTURED_RESPONSE = Field(...) - author: Literal["assistant"] = Field(default="assistant", const=True) + author: Literal["assistant"] = Field(default="assistant") references: Union[MessageReferences, None] = Field(default=None) diff --git a/src/firedust/utils/api.py b/src/firedust/utils/api.py index 49304ae..57628f2 100644 --- a/src/firedust/utils/api.py +++ b/src/firedust/utils/api.py @@ -1,35 +1,18 @@ import os -from typing import Any, AsyncIterator, Dict, Iterator, Optional +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Iterator, Optional, Type import httpx from firedust.utils.errors import MissingFiredustKeyError -BASE_URL = "https://api.firedust.dev" +BASE_URL = "http://localhost:3002" +# BASE_URL = "https://api.firedust.dev" TIMEOUT = 300 -class SyncAPIClient: - """ - A synchronous client for interacting with the Firedust API. - - Attributes: - base_url (str): The base URL of the Firedust API. - api_key (str): The API key used for authentication. - headers (Dict[str, str]): The headers to be included in the requests. - """ - +class BaseAPIClient: def __init__(self, api_key: Optional[str] = None, base_url: str = BASE_URL) -> None: - """ - Initializes a new instance of the SyncAPIClient class. - - Args: - api_key (str, optional): The API key to authenticate requests. If not provided, it will be fetched from the environment variable "FIREDUST_API_KEY". Defaults to None. - base_url (str, optional): The base URL of the Firedust API. Defaults to BASE_URL. - - Raises: - MissingFiredustKeyError: If the API key is not provided and not found in the environment variable. - """ api_key = api_key or os.environ.get("FIREDUST_API_KEY") if not api_key: raise MissingFiredustKeyError() @@ -41,6 +24,26 @@ def __init__(self, api_key: Optional[str] = None, base_url: str = BASE_URL) -> N "Authorization": f"Bearer {api_key}", } + +class SyncAPIClient(BaseAPIClient): + def __init__(self, api_key: Optional[str] = None, base_url: str = BASE_URL) -> None: + super().__init__(api_key, base_url) + self.client: httpx.Client = httpx.Client(timeout=TIMEOUT, headers=self.headers) + + def __enter__(self) -> "SyncAPIClient": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.client.close() + + def __del__(self) -> None: + self.client.close() + def get(self, url: str, params: Optional[Dict[str, Any]] = None) -> httpx.Response: return self._request("get", url, params=params) @@ -59,9 +62,7 @@ def get_stream( self, url: str, params: Optional[Dict[str, Any]] = None ) -> Iterator[bytes]: url = self.base_url + url - with httpx.stream( - "get", url, params=params, headers=self.headers, timeout=TIMEOUT - ) as response: + with self.client.stream("get", url, params=params) as response: for chunk in response.iter_bytes(): yield chunk @@ -69,9 +70,7 @@ def post_stream( self, url: str, data: Optional[Dict[str, Any]] = None ) -> Iterator[bytes]: url = self.base_url + url - with httpx.stream( - "post", url, json=data, headers=self.headers, timeout=TIMEOUT - ) as response: + with self.client.stream("post", url, json=data) as response: for chunk in response.iter_bytes(): yield chunk @@ -83,43 +82,34 @@ def _request( data: Optional[Dict[str, Any]] = None, ) -> httpx.Response: url = self.base_url + url - response = httpx.request( - method, url, params=params, json=data, headers=self.headers, timeout=TIMEOUT - ) + response = self.client.request(method, url, params=params, json=data) return response + def close(self) -> None: + """ + Close the underlying HTTP client. + """ + self.client.close() -class AsyncAPIClient: - """ - An asynchronous client for interacting with the Firedust API. - - Attributes: - base_url (str): The base URL of the Firedust API. - api_key (str): The API key used for authentication. - headers (Dict[str, str]): The headers to be included in the requests. - """ +class AsyncAPIClient(BaseAPIClient): def __init__(self, api_key: Optional[str] = None, base_url: str = BASE_URL) -> None: - """ - Initializes a new instance of the AsyncAPIClient class. + super().__init__(api_key, base_url) + self.client = httpx.AsyncClient(timeout=TIMEOUT, headers=self.headers) - Args: - api_key (str, optional): The API key to authenticate requests. If not provided, it will be fetched from the environment variable "FIREDUST_API_KEY". Defaults to None. - base_url (str, optional): The base URL of the Firedust API. Defaults to BASE_URL. + async def __aenter__(self) -> "AsyncAPIClient": + return self - Raises: - MissingFiredustKeyError: If the API key is not provided and not found in the environment variable. - """ - api_key = api_key or os.environ.get("FIREDUST_API_KEY") - if not api_key: - raise MissingFiredustKeyError() + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.client.aclose() - self.base_url = base_url - self.api_key = api_key - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - } + async def __del__(self) -> None: + await self.client.aclose() async def get( self, url: str, params: Optional[Dict[str, Any]] = None @@ -145,23 +135,17 @@ async def get_stream( self, url: str, params: Optional[Dict[str, Any]] = None ) -> AsyncIterator[bytes]: url = self.base_url + url - async with httpx.AsyncClient() as client: - async with client.stream( - "get", url, params=params, headers=self.headers, timeout=TIMEOUT - ) as response: - async for chunk in response.aiter_bytes(): - yield chunk + async with self.client.stream("get", url, params=params) as response: + async for chunk in response.aiter_bytes(): + yield chunk async def post_stream( self, url: str, data: Optional[Dict[str, Any]] = None ) -> AsyncIterator[bytes]: url = self.base_url + url - async with httpx.AsyncClient() as client: - async with client.stream( - "post", url, json=data, headers=self.headers, timeout=TIMEOUT - ) as response: - async for chunk in response.aiter_bytes(): - yield chunk + async with self.client.stream("post", url, json=data) as response: + async for chunk in response.aiter_bytes(): + yield chunk async def _request( self, @@ -171,13 +155,11 @@ async def _request( data: Optional[Dict[str, Any]] = None, ) -> httpx.Response: url = self.base_url + url - async with httpx.AsyncClient() as client: - response = await client.request( - method, - url, - params=params, - json=data, - headers=self.headers, - timeout=TIMEOUT, - ) - return response + response = await self.client.request(method, url, params=params, json=data) + return response + + async def close(self) -> None: + """ + Close the underlying HTTP client. + """ + await self.client.aclose() diff --git a/tests/assistant/test_chat.py b/tests/assistant/test_chat.py index fb798b9..e4f8498 100644 --- a/tests/assistant/test_chat.py +++ b/tests/assistant/test_chat.py @@ -330,7 +330,7 @@ async def test_async_chat_structured_simple() -> None: assert isinstance(response.message["name"], str) assert response.message["name"] == "Jane Smith" assert isinstance(response.message["occupation"], str) - assert response.message["occupation"] == "data scientist" + assert response.message["occupation"].lower() == "data scientist" finally: await assistant.delete(confirm=True) @@ -363,22 +363,8 @@ def test_chat_structured_complex_message() -> None: assistant = firedust.assistant.create(f"test-assistant-{random.randint(1, 1000)}") schema: STRUCTURED_SCHEMA = { - "company": DictField( - hint="The details of the company.", - items={ - "name": StringField(hint="The name of the company."), - "description": StringField(hint="A brief description of the company."), - "values": DictField( - hint="The values of the company.", - items={ - "name": StringField(hint="The name of the value."), - "description": StringField( - hint="A brief description of the value." - ), - }, - ), - }, - ), + "name": StringField(hint="The name of the company."), + "description": StringField(hint="A brief description of the company."), "products": ListField( hint="The products and services offered by the company.", items=StringField(hint="The product or service name."), @@ -392,9 +378,21 @@ def test_chat_structured_complex_message() -> None: "energy", ], ), - "weapons": BooleanField(hint="Whether the company produces weapons."), - "employees": FloatField( - hint="The estimated number of employees, assumed from context." + "weapons": BooleanField( + hint="True if the company is involved in weapons production." + ), + "employees": FloatField(hint="Approximate number of employees in the company."), + "values": ListField( + hint="The values of the company.", + items=DictField( + hint="The name and description of a value of the company.", + items={ + "name": StringField(hint="The name of the value."), + "description": StringField( + hint="A brief description of the value." + ), + }, + ), ), } @@ -409,16 +407,25 @@ def test_chat_structured_complex_message() -> None: # Check the response assert isinstance(response, StructuredAssistantMessage) assert schema.keys() == response.message.keys() - - assert isinstance(response.message["company"], dict) - assert isinstance(response.message["company"]["name"], str) - assert response.message["company"]["name"].lower() == "general dynamics" + assert isinstance(response.message["name"], str) + assert response.message["name"].lower() == "general dynamics" + assert isinstance(response.message["description"], str) + assert isinstance(response.message["values"], list) + for value in response.message["values"]: + assert isinstance(value, dict) + assert isinstance(value["name"], str) + assert isinstance(value["description"], str) assert isinstance(response.message["products"], list) assert all(isinstance(product, str) for product in response.message["products"]) assert isinstance(response.message["industry"], str) - assert "defense" == response.message["industry"].lower() + assert response.message["industry"].lower() in [ + "agriculture", + "defense", + "chemicals", + "energy", + ] assert isinstance(response.message["weapons"], bool) assert response.message["weapons"] is True