Skip to content

Commit

Permalink
wip - integrate and test the api interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ion2088 committed Feb 11, 2024
1 parent 728de88 commit 9925119
Show file tree
Hide file tree
Showing 17 changed files with 153 additions and 122 deletions.
3 changes: 2 additions & 1 deletion .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
"runArgs": [
"--network=host",
"-e", "DEV_NAME=${localEnv:DEV_NAME}",
"-e", "DEV_EMAIL=${localEnv:DEV_EMAIL}"
"-e", "DEV_EMAIL=${localEnv:DEV_EMAIL}",
"-e", "FIREDUST_API_KEY=${localEnv:FIREDUST_API_KEY}"
],
"mounts": [
"source=/home/ubuntu/.ssh,target=/root/.ssh,type=bind,consistency=cached"
Expand Down
51 changes: 17 additions & 34 deletions src/firedust/_utils/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import httpx

from firedust._utils.errors import APIError, MissingFiredustKeyError
from firedust._utils.errors import MissingFiredustKeyError

BASE_URL = "https://api.firedust.ai/v1"
# BASE_URL = "https://api.firedust.ai/v1"
BASE_URL = "http://0.0.0.0:8080"


class APIClient:
Expand Down Expand Up @@ -41,16 +42,16 @@ def __init__(self, api_key: str | None = None, base_url: str = BASE_URL) -> None
}

# sync methods
def get(self, url: str, params: Dict[str, Any] | None = None) -> Dict[str, Any]:
def get(self, url: str, params: Dict[str, Any] | None = None) -> httpx.Response:
return self._request_sync("get", url, params=params)

def post(self, url: str, data: Dict[str, Any] | None = None) -> Dict[str, Any]:
def post(self, url: str, data: Dict[str, Any] | None = None) -> httpx.Response:
return self._request_sync("post", url, data=data)

def put(self, url: str, data: Dict[str, Any] | None = None) -> Dict[str, Any]:
def put(self, url: str, data: Dict[str, Any] | None = None) -> httpx.Response:
return self._request_sync("put", url, data=data)

def delete(self, url: str) -> Dict[str, Any]:
def delete(self, url: str) -> httpx.Response:
return self._request_sync("delete", url)

def get_stream(
Expand All @@ -72,20 +73,20 @@ def post_stream(
# async methods
async def get_async(
self, url: str, params: Dict[str, Any] | None = None
) -> Dict[str, Any]:
) -> httpx.Response:
return await self._request_async("get", url, params=params)

async def post_async(
self, url: str, data: Dict[str, Any] | None = None
) -> Dict[str, Any]:
) -> httpx.Response:
return await self._request_async("post", url, data=data)

async def put_async(
self, url: str, data: Dict[str, Any] | None = None
) -> Dict[str, Any]:
) -> httpx.Response:
return await self._request_async("put", url, data=data)

async def delete_async(self, url: str) -> Dict[str, Any]:
async def delete_async(self, url: str) -> httpx.Response:
return await self._request_async("delete", url)

async def get_stream_async(
Expand Down Expand Up @@ -117,41 +118,23 @@ def _request_sync(
url: str,
params: Dict[str, Any] | None = None,
data: Dict[str, Any] | None = None,
) -> Dict[str, Any] | Any:
) -> httpx.Response:
url = self.base_url + url
response = httpx.request(
method, url, params=params, json=data, headers=self.headers
method, url, params=params, json=data, headers=self.headers, timeout=30
)
_handle_status_codes(response)

return response.json()
return response

async def _request_async(
self,
method: str,
url: str,
params: Dict[str, Any] | None = None,
data: Dict[str, Any] | None = None,
) -> Dict[str, Any] | Any:
) -> httpx.Response:
url = self.base_url + url
async with httpx.AsyncClient() as client:
response = await client.request(
method, url, params=params, json=data, headers=self.headers
method, url, params=params, json=data, headers=self.headers, timeout=30
)
_handle_status_codes(response)

return response.json()


def _handle_status_codes(response: httpx.Response) -> None:
if response.status_code == 400:
raise APIError("Bad Request", response.status_code)
elif response.status_code == 401:
raise APIError("Unauthorized", response.status_code)
elif response.status_code == 403:
raise APIError("Forbidden", response.status_code)
elif response.status_code == 404:
raise APIError("Not Found", response.status_code)
elif response.status_code == 500:
raise APIError("Internal Server Error", response.status_code)
# TODO: Customize status codes and error messages
return response
8 changes: 6 additions & 2 deletions src/firedust/_utils/types/_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, field_serializer

UNIX_TIMESTAMP = int | float
UNIX_TIMESTAMP = float # see: https://www.unixtimestamp.com/


class BaseConfig(BaseModel):
Expand All @@ -20,3 +20,7 @@ def __setattr__(self, key: str, value: Any) -> None:
raise AttributeError("Cannot set attribute 'id', it is immutable.")

return super().__setattr__(key, value)

@field_serializer("id", when_used="always")
def serialize_id(self, value: UUID) -> str:
return str(value)
4 changes: 0 additions & 4 deletions src/firedust/_utils/types/ability.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from abc import ABC
from typing import Any, List, Literal
from uuid import UUID, uuid4

from ._base import BaseConfig

ABILITY_ID = UUID


class AbilityConfig(BaseConfig, ABC):
"""
Represents a configuration for an ability.
"""

id: ABILITY_ID = uuid4()
name: str
description: str
instructions: List[str]
Expand Down
14 changes: 11 additions & 3 deletions src/firedust/_utils/types/assistant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List, Literal
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, field_serializer

from ._base import UNIX_TIMESTAMP, BaseConfig
from .ability import AbilityConfig
Expand All @@ -15,7 +15,6 @@ class AssistantConfig(BaseConfig):
Represents the configuration of an AI Assistant.
Args:
id (UUID): The unique identifier of the assistant.
name (str): The name of the assistant.
instructions (List[str]): The instructions of the assistant.
inference (InferenceConfig): The inference configuration of the assistant.
Expand All @@ -24,7 +23,6 @@ class AssistantConfig(BaseConfig):
deployments (List[Interface], optional): The deployments of the assistant. Defaults to None.
"""

id: UUID
name: str
instructions: List[str]
inference: InferenceConfig
Expand Down Expand Up @@ -62,3 +60,13 @@ class Message(BaseModel):
author: Literal["user", "assistant"]
text: str
timestamp: UNIX_TIMESTAMP

@field_serializer("assistant_id", when_used="always")
def serialize_assistant_id(self, value: UUID) -> str:
return str(value)

@field_serializer("user_id", when_used="always")
def serialize_user_id(self, value: UUID | None) -> str | None:
if value is None:
return None
return str(value)
45 changes: 41 additions & 4 deletions src/firedust/_utils/types/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, List, Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_serializer, field_validator

from firedust._utils import checks

Expand Down Expand Up @@ -37,12 +37,14 @@ class MemoryItem(BaseConfig):
source: str | None = None

@field_validator("context")
@classmethod
def validate_context_length(cls, context: str) -> str | Exception:
if len(context) > 2000:
raise ValueError("Memory context exceeds maximum length of 2000 characters")
return context

@field_validator("timestamp")
@classmethod
def validate_timestamp(
cls, timestamp: UNIX_TIMESTAMP
) -> UNIX_TIMESTAMP | Exception:
Expand All @@ -53,23 +55,50 @@ def validate_timestamp(
raise ValueError(f"Invalid timestamp: {e}")
return timestamp

@field_serializer("collection", when_used="always")
def serialize_id(self, value: UUID) -> str:
return str(value)

class MemoriesCollectionItem(BaseConfig):

class MemoriesCollection(BaseConfig):
"""
Represents a collection of memories used by the assistant.
"""

collection: List[UUID] | None = None
collection_id: UUID
memory_ids: List[UUID] | None = None

def __setattr__(self, key: str, value: Any) -> None:
# set immutable attributes
if key == "collection_id":
raise AttributeError(
"""
Cannot set attribute 'collection_id', it is immutable.
To add a memory to the collection use assistant.memory.collection.add method.
"""
)

return super().__setattr__(key, value)

@field_serializer("collection_id", when_used="always")
def serialize_collection_id(self, value: UUID) -> str:
return str(value)

@field_serializer("memory_ids", when_used="always")
def serialize_memory_ids(self, value: List[UUID] | None) -> List[str] | None:
if value is None:
return None
return [str(memory_id) for memory_id in value]


class MemoryConfig(BaseModel):
"""
Configuration for Assistant's memory.
"""

default_collection: UUID = Field(default_factory=uuid4)
embedding_model: EMBEDDING_MODELS = "mistral-embed"
embedding_model_provider: EMBEDDING_PROVIDERS = "mistral"
default_collection: UUID = uuid4()
extra_collections: List[UUID] = []

def __setattr__(self, key: str, value: Any) -> None:
Expand All @@ -90,3 +119,11 @@ def __setattr__(self, key: str, value: Any) -> None:
)

return super().__setattr__(key, value)

@field_serializer("default_collection", when_used="always")
def serialize_default_collection(self, value: UUID) -> str:
return str(value)

@field_serializer("extra_collections", when_used="always")
def serialize_extra_collections(self, value: List[UUID]) -> List[str]:
return [str(collection) for collection in value]
2 changes: 1 addition & 1 deletion src/firedust/ability/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def write(self, instruction: str) -> str:
f"assistant/{self.config.id}/ability/code/write",
data={"instruction": instruction},
)
code: str = response["data"]["code"]
code: str = response.json()["data"]["code"]
return code

def execute(self, code: str) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/firedust/ability/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def create(self, config: CustomAbilityConfig) -> None:
"""
response = self.api_client.post(
f"assistants/{self.config.id}/abilities/custom/create",
data={"ability": config.model_dump_json()},
data={"ability": config.model_dump()},
)

if response["status_code"] != 200:
if response.status_code != 200:
raise CustomAbilityError("Failed to create custom ability.")

def execute(self, ability_id: str, instruction: str) -> str:
Expand All @@ -74,8 +74,8 @@ def execute(self, ability_id: str, instruction: str) -> str:
data={"id": ability_id, "instruction": instruction},
)

if response["status_code"] != 200:
if response.status_code != 200:
raise CustomAbilityError("Failed to execute custom ability.")

result: str = response["data"]["result"]
result: str = response.json()["data"]["result"]
return result
2 changes: 1 addition & 1 deletion src/firedust/ability/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ def solve(self, problem: str) -> str:
f"assistant/{self.config.id}/ability/math/solve",
data={"problem": problem},
)
solution: str = response["data"]["solution"]
solution: str = response.json()["data"]["solution"]

return solution
18 changes: 9 additions & 9 deletions src/firedust/agency/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def run(self, task: str) -> Dict[str, str]:
data={"task": task},
)

if response["status_code"] != 200:
raise Exception(response["message"])
if response.status_code != 200:
raise Exception(response.json()["message"])

metadata: Dict[str, str] = response
metadata: Dict[str, str] = response.json()

return metadata

Expand Down Expand Up @@ -89,10 +89,10 @@ def add(self, task: ScheduledTask) -> Dict[str, str]:
data={"task": task},
)

if response["status_code"] != 200:
raise Exception(response["message"])
if response.status_code != 200:
raise Exception(response.json()["message"])

metadata: Dict[str, str] = response
metadata: Dict[str, str] = response.json()

return metadata

Expand All @@ -107,10 +107,10 @@ def list(self) -> List[ScheduledTask]:
f"assistant/{self.config.id}/agency/task/schedule/list",
)

if response["status_code"] != 200:
raise Exception(response["message"])
if response.status_code != 200:
raise Exception(response.json()["message"])

scheduled_tasks: List[ScheduledTask] = [
ScheduledTask(**task) for task in response["result"]
ScheduledTask(**task) for task in response.json()["result"]
]
return scheduled_tasks
Loading

0 comments on commit 9925119

Please sign in to comment.