From 51c8114d6b020be689e27ff0d8b4af1642c2c3c8 Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Thu, 7 Nov 2024 17:27:23 -0800 Subject: [PATCH 01/14] refactored DAGNode and Project to use a dictionary for nodes instead of a set, basic guardrailless endpoints for project execution --- .../concrete/projects/dag_project.py | 53 +++++++++------ webapp/api/server.py | 64 ++++++++++++------- 2 files changed, 74 insertions(+), 43 deletions(-) diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index bb7a3b61..323e8d73 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -16,18 +16,18 @@ def __init__( self, options: dict = {}, ) -> None: - self.edges: dict[DAGNode, list[tuple[DAGNode, str, Callable]]] = defaultdict(list) + self.edges: dict[str, list[tuple[str, str, Callable]]] = defaultdict(list) self.options = options - self.nodes: set[DAGNode] = set() + self.nodes: dict[str, DAGNode] = {} def add_edge( self, - child: "DAGNode", - parent: "DAGNode", + parent: str, + child: str, res_name: str, res_transformation: Callable = lambda x: x, - ) -> None: + ) -> tuple[str, str, str]: """ child: Downstream node parent: Upstream node @@ -40,8 +40,17 @@ def add_edge( self.edges[parent].append((child, res_name, res_transformation)) - def add_node(self, node: "DAGNode") -> None: - self.nodes.add(node) + return (parent, child, res_name) + + def add_node(self, name: str, node: "DAGNode") -> "DAGNode": + if name != node.name: + node.name = name + if node.name == "" or node.name in self.nodes: + # TODO: implement random name generator bandit is happy with + # https://www.geeksforgeeks.org/python-generate-random-string-of-given-length/ does not fly + node.name = max(self.nodes, default="") + "1" + self.nodes[node.name] = node + return node async def execute(self) -> AsyncGenerator[tuple[str, str], None]: if not self.is_dag: @@ -55,28 +64,28 @@ async def execute(self) -> AsyncGenerator[tuple[str, str], None]: while no_dep_nodes: ready_node = no_dep_nodes.pop() - operator_name, res = await ready_node.execute(self.options) + operator_name, res = await self.nodes[ready_node].execute(self.options) yield (operator_name, res) for child, res_name, res_transformation in self.edges[ready_node]: - child.update(res_transformation(res), res_name) + self.nodes[child].update(res_name, res_transformation(res)) node_dep_count[child] -= 1 if node_dep_count[child] == 0: no_dep_nodes.add(child) @property - def is_dag(self): + def is_dag(self) -> bool: # AI generated - visited = set() - rec_stack = set() + visited: set[str] = set() + rec_stack: set[str] = set() - def dfs(node: DAGNode) -> bool: + def dfs(node: str) -> bool: if node not in visited: visited.add(node) rec_stack.add(node) - for child, _, _ in self.edges.get(node, []): + for child, _, _ in self.edges[node]: if child not in visited: if not dfs(child): return False @@ -102,7 +111,8 @@ class DAGNode: def __init__( self, - task: str, + name: str, + boost: str, operator: Operator, default_task_kwargs: dict[str, Any] = {}, options: dict[str, Any] = {}, @@ -114,17 +124,18 @@ def __init__( options: Maps to OperatorOptions. 
Can also be set in default_task_kwargs as {'options': {...}} """ try: - self.bound_task = getattr(operator, task) + self.bound_task = getattr(operator, boost) except AttributeError: - raise ValueError(f"{operator} does not have a method {task}") + raise ValueError(f"{operator} does not have a method {boost}") self.operator: Operator = operator - self.task_str = task + self.name = name + self.boost_str = boost self.dynamic_kwargs: dict[str, Any] = {} self.default_task_kwargs = default_task_kwargs # TODO probably want to manage this in the project self.options = options # Could also throw this into default_task_kwargs - def update(self, dyn_kwarg_value, dyn_kwarg_name) -> None: + def update(self, dyn_kwarg_name, dyn_kwarg_value) -> None: self.dynamic_kwargs[dyn_kwarg_name] = dyn_kwarg_value async def execute(self, options: dict = {}) -> Any: @@ -137,7 +148,7 @@ async def execute(self, options: dict = {}) -> Any: if options.get("run_async"): res = res.get().message - return type(self.operator).__name__, res + return self.name, res def __str__(self): - return f"{type(self.operator).__name__}.{self.task_str}(**{self.default_task_kwargs})" + return f"{type(self.operator).__name__}.{self.boost_str}(**{self.default_task_kwargs})" diff --git a/webapp/api/server.py b/webapp/api/server.py index eda0248a..184c39ce 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -1,6 +1,6 @@ import os from collections.abc import Callable, Sequence -from typing import Annotated +from typing import Annotated, Any from uuid import UUID import dotenv @@ -257,7 +257,7 @@ def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> @app.post("/build/project") -def init_project(name: str): +def initialize_project(name: str) -> str: """ Initiate a directed-acyclic-graph (DAG) project locally. Projects must be unique in name. @@ -273,36 +273,56 @@ def init_project(name: str): if name in PROJECTS: raise HTTPException(status_code=400, detail="{name} already exists as a Project!") PROJECTS[name] = Project() - return PROJECTS[name] + return name -@app.post("/build/project/{project_name}/node") -def expand_project_with_method(project_name: str, operator_name: str, task: str): +@app.post("/build/project/{project}/task") +def expand_project_with_task( + project: str, operator: str, boost: str = "chat", task: str = "", default_task_kwargs: dict[str, Any] = {} +) -> str: """ Expand a project by adding an operator task as a node in its DAG. - project_name: The name of the project to be expanded. - operator_name: The name of the operator whose task we'd like to use. - task: The name of the operator's task to add as a node. + project: The name of the project to be expanded. + operator: The name of the operator whose task we'd like to use. + boost: The name of the operator's prompt boost to add as a node. + task: The name of the task this node represents. + default_task_kwargs: Any default arguments to pass to the task. 
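        A minimal client-side sketch of these guardrail-less endpoints as they stand in this patch. The base URL is an assumption (a locally served instance of webapp/api/server.py), and "Operator" stands in for any class exported by concrete.operators that can be constructed without arguments; the query parameters carry the node metadata while the JSON body supplies default_task_kwargs.

            import requests

            BASE = "http://localhost:8000"  # assumed local dev server

            # Register the in-memory project, then attach one chat node to it.
            requests.post(f"{BASE}/build/project", params={"name": "demo"})
            requests.post(
                f"{BASE}/build/project/demo/task",
                params={"operator": "Operator", "boost": "chat", "task": "greet"},
                json={"message": "Hello!"},  # becomes the node's default_task_kwargs
            )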
""" - if project_name not in PROJECTS: - raise project_not_found(project_name) - project = PROJECTS[project_name] - node = DAGNode(task, getattr(operators, operator_name)) - project.add_node(node) - return project + if project not in PROJECTS: + raise project_not_found(project) + project_obj = PROJECTS[project] + node = DAGNode(task, boost, getattr(operators, operator)(), default_task_kwargs) + return project_obj.add_node(task, node).name -@app.post("/build/project/{project_name}/edge") -def expand_project_with_connection(project_name: str, parent_name: str, child_name: str): +@app.post("/build/project/{project}/edge") +def expand_project_with_connection(project: str, parent: str, child: str, input_to_child: str) -> tuple[str, str, str]: """ - WIP. Expand a project by connecting two tasks together. The output from the parent task will be fed into the child task. - project_name: The name of the project to be expanded. - parent_name: The name of the parent task in the connection. - child_name: The name of the child task in the connection. + project: The name of the project to be expanded. + parent: The name of the parent task in the connection. + child: The name of the child task in the connection. + input_to_child: The name of the input to the child (equivalently, the output from the parent) """ - if project_name not in PROJECTS: - raise project_not_found(project_name) + if project not in PROJECTS: + raise project_not_found(project) + project_obj = PROJECTS[project] + return project_obj.add_edge(parent, child, input_to_child) + + +@app.post("/build/project/{project}/run") +async def run_project(project: str): + """ + Run a project from its sources to its sinks. + + project: The name of the project to be run. + """ + if project not in PROJECTS: + raise project_not_found(project) + project_obj = PROJECTS[project] + async for operator, response in project_obj.execute(): + print(operator) + print(response) From 3d277e3d45eef90e85060783e06c65d4244aa025 Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 13 Nov 2024 02:05:07 -0800 Subject: [PATCH 02/14] WIP DB integration --- .../concrete/projects/__init__.py | 2 - .../concrete/projects/dag_project.py | 8 +-- src/concrete-db/concrete_db/crud.py | 53 ++++++++++++--- src/concrete-db/concrete_db/orm/models.py | 64 +++++++++++++++++++ webapp/api/server.py | 19 ++++-- 5 files changed, 126 insertions(+), 20 deletions(-) diff --git a/src/concrete-core/concrete/projects/__init__.py b/src/concrete-core/concrete/projects/__init__.py index 95abd7fa..e1bd6269 100644 --- a/src/concrete-core/concrete/projects/__init__.py +++ b/src/concrete-core/concrete/projects/__init__.py @@ -1,6 +1,4 @@ from .dag_project import DAGNode, Project from .software_project import SoftwareProject -PROJECTS: dict[str, Project] = {} - __all__ = ["DAGNode", "Project", "SoftwareProject"] diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index 323e8d73..e1f1dccb 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -112,7 +112,7 @@ class DAGNode: def __init__( self, name: str, - boost: str, + task: str, operator: Operator, default_task_kwargs: dict[str, Any] = {}, options: dict[str, Any] = {}, @@ -124,13 +124,13 @@ def __init__( options: Maps to OperatorOptions. 
Can also be set in default_task_kwargs as {'options': {...}} """ try: - self.bound_task = getattr(operator, boost) + self.bound_task = getattr(operator, task) except AttributeError: - raise ValueError(f"{operator} does not have a method {boost}") + raise ValueError(f"{operator} does not have a method {task}") self.operator: Operator = operator self.name = name - self.boost_str = boost + self.boost_str = task self.dynamic_kwargs: dict[str, Any] = {} self.default_task_kwargs = default_task_kwargs # TODO probably want to manage this in the project self.options = options # Could also throw this into default_task_kwargs diff --git a/src/concrete-db/concrete_db/crud.py b/src/concrete-db/concrete_db/crud.py index 08f83a19..62ace387 100644 --- a/src/concrete-db/concrete_db/crud.py +++ b/src/concrete-db/concrete_db/crud.py @@ -13,6 +13,12 @@ Client, ClientCreate, ClientUpdate, + DagNode, + DagNodeCreate, + DagNodeToDagNodeLink, + DagProject, + DagProjectCreate, + DagProjectUpdate, Message, MessageCreate, MessageUpdate, @@ -73,7 +79,7 @@ def delete_generic(db: Session, model: M | None) -> M | None: return model -# ===Operator=== # +# region Operator CRUD # TODO: automate project creation via DML trigger/event @@ -151,7 +157,8 @@ def delete_operator(db: Session, operator_id: UUID, orchestrator_id: UUID) -> Op ) -# ===Client=== # +# endregion +# region Client CRUD def create_client(db: Session, client_create: ClientCreate) -> Client: @@ -218,7 +225,8 @@ def delete_client( ) -# ===Tool=== # +# endregion +# region Tool CRUD def create_tool(db: Session, tool_create: ToolCreate, user_id: UUID) -> Tool: @@ -287,7 +295,8 @@ def assign_tool_to_operator(db: Session, operator_id: UUID, tool_id: UUID) -> Op ) -# ===Message=== # +# endregion +# region Message CRUD def create_message(db: Session, message_create: MessageCreate) -> Message: @@ -356,7 +365,8 @@ def delete_message( return delete_generic(db, get_message(db, message_id)) -# ===Orchestrator=== # +# endregion +# region Orchestrator CRUD def create_orchestrator(db: Session, orchestrator_create: OrchestratorCreate) -> Orchestrator: @@ -409,7 +419,8 @@ def delete_orchestrator(db: Session, orchestrator_id: UUID, user_id: UUID | None ) -# ===Project=== # +# endregion +# region Project CRUD def create_project(db: Session, project_create: ProjectCreate) -> Project: @@ -463,7 +474,27 @@ def delete_project(db: Session, project_id: UUID, orchestrator_id: UUID) -> Proj ) -# ===Node=== # +# endregion +# region DagProject CRUD + + +def create_dag_project(db: Session, dag_project_create: DagProjectCreate) -> DagProject: + return create_generic(db, DagProject(**dag_project_create.model_dump())) + + +def create_dag_node(db: Session, dag_node_create: DagNodeCreate) -> DagNode: + return create_generic(db, DagNode(**dag_node_create.model_dump())) + + +def get_dag_project_by_name(db: Session, name: str) -> DagProject | None: + stmt = select(DagProject).where(DagProject.name == name) + return db.scalars(stmt).first() + + +# endregion +# region Node CRUD + + def create_node(db: Session, node_create: NodeCreate) -> Node: return create_generic(db, Node(**node_create.model_dump())) @@ -495,7 +526,10 @@ def get_repo_node_by_path(db: Session, org: str, repo: str, abs_path: str, branc return db.scalars(stmt).first() -# ===User Auth=== # +# endregion +# region Auth CRUD + + def create_authstate(db: Session, authstate_create: AuthStateCreate) -> AuthState: return create_generic( db, @@ -522,3 +556,6 @@ def get_user(db: Session, email: str) -> User | None: def create_authtoken(db: 
Session, authtoken_create: AuthTokenCreate) -> AuthToken: return create_generic(db, AuthToken(**authtoken_create.model_dump())) + + +# endregion diff --git a/src/concrete-db/concrete_db/orm/models.py b/src/concrete-db/concrete_db/orm/models.py index 6af10dca..aa3bbfc2 100644 --- a/src/concrete-db/concrete_db/orm/models.py +++ b/src/concrete-db/concrete_db/orm/models.py @@ -65,6 +65,14 @@ class UserToolLink(Base, table=True): tool_id: UUID = Field(foreign_key="tool.id", primary_key=True, ondelete="CASCADE") +class DagNodeToDagNodeLink(Base, table=True): + parent: UUID = Field(foreign_key="dagnode.id", primary_key=True, index=True, ondelete="CASCADE") + child: UUID = Field(foreign_key="dagnode.id", primary_key=True, ondelete="CASCADE") + kwarg_name: str = Field(description="Name of the argument to the child task") + + # TODO maybe store transformation function + + # region User Models @@ -326,6 +334,62 @@ class Project(ProjectBase, MetadataMixin, table=True): messages: list["Message"] = Relationship(back_populates="project", cascade_delete=True) +# endregion +# region DagProject Models + + +class DagProjectBase(Base): + name: str = Field( + description="Name of the project.", + unique=True, + max_length=64, + ) + + +class DagProjectUpdate(DagProjectBase): + name: str | None = Field(description="Name of the project.", max_length=64, default=None) + + +class DagProjectCreate(DagProjectBase): + pass + + +class DagProject(DagProjectBase, MetadataMixin, table=True): + edges: list["DagNodeToDagNodeLink"] = Relationship(back_populates="project", cascade_delete=True) + nodes: list["DagNode"] = Relationship(back_populates="project", cascade_delete=True) + + +class DagNodeBase(Base): + name: str = Field( + description="Name of the DAG Project Node.", + unique=True, + max_length=64, + ) + task: str = Field(description="Name of method on Operator (e.g. 'chat')") + operator_id: UUID = Field( + description="ID of Operator encapsulated by this DAG Node.", + foreign_key="operator.id", + ondelete="CASCADE", + ) + default_task_kwargs: str = Field( + description="Default kwargs for the task as JSON.", + default="{}", + ) + # TODO: options + + +class DagNodeUpdate(Base): + pass + + +class DagNodeCreate(DagNodeBase): + pass + + +class DagNode(DagNodeBase, MetadataMixin, table=True): + operator: Operator = Relationship() + + # endregion # region Tool Models diff --git a/webapp/api/server.py b/webapp/api/server.py index 184c39ce..e51b60b3 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -4,7 +4,7 @@ from uuid import UUID import dotenv -from concrete.projects import PROJECTS, DAGNode, Project +from concrete.projects import DAGNode, Project from concrete.webutils import AuthMiddleware from concrete_db import crud from concrete_db.orm import Session @@ -12,6 +12,7 @@ Client, ClientCreate, ClientUpdate, + DagProjectCreate, Operator, OperatorCreate, OperatorUpdate, @@ -257,7 +258,7 @@ def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> @app.post("/build/project") -def initialize_project(name: str) -> str: +def initialize_project(dag_project_create: DagProjectCreate) -> str: """ Initiate a directed-acyclic-graph (DAG) project locally. Projects must be unique in name. @@ -270,10 +271,14 @@ def initialize_project(name: str) -> str: name: The name of the project to be initialized. 
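        With this WIP change the project name moves from a query parameter into a DagProjectCreate request body, so a client call would now look roughly like the sketch below (base URL assumed; the route is still /build/project at this point in the series).

            import requests

            # The JSON body maps onto DagProjectCreate, whose only field is `name`.
            requests.post("http://localhost:8000/build/project", json={"name": "demo"})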
""" - if name in PROJECTS: - raise HTTPException(status_code=400, detail="{name} already exists as a Project!") - PROJECTS[name] = Project() - return name + name = dag_project_create.name + with Session() as session: + db_project = crud.get_dag_project_by_name(session, name) + if db_project is not None: + raise HTTPException(status_code=400, detail="{name} already exists as a Project!") + db_project = crud.create_dag_project(session, dag_project_create) + + return db_project @app.post("/build/project/{project}/task") @@ -289,6 +294,8 @@ def expand_project_with_task( task: The name of the task this node represents. default_task_kwargs: Any default arguments to pass to the task. """ + with Session() as session: + db_project = crud.get_dag_project_by_name(session) if project not in PROJECTS: raise project_not_found(project) project_obj = PROJECTS[project] From 052cec9a880bef91e15ca52a7f21619f81276ea6 Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 13 Nov 2024 15:44:19 -0800 Subject: [PATCH 03/14] testless, first iteration of DAGProject API --- src/concrete-db/concrete_db/crud.py | 44 ++++- src/concrete-db/concrete_db/orm/models.py | 69 ++++++-- webapp/api/server.py | 200 +++++++++++++++++----- 3 files changed, 250 insertions(+), 63 deletions(-) diff --git a/src/concrete-db/concrete_db/crud.py b/src/concrete-db/concrete_db/crud.py index 62ace387..eff75e12 100644 --- a/src/concrete-db/concrete_db/crud.py +++ b/src/concrete-db/concrete_db/crud.py @@ -14,11 +14,10 @@ ClientCreate, ClientUpdate, DagNode, - DagNodeCreate, + DagNodeBase, DagNodeToDagNodeLink, DagProject, DagProjectCreate, - DagProjectUpdate, Message, MessageCreate, MessageUpdate, @@ -45,6 +44,9 @@ UserToolLink, ) +# region Generic Utils + + M = TypeVar("M", bound=Base) N = TypeVar("N", bound=Base) @@ -79,6 +81,7 @@ def delete_generic(db: Session, model: M | None) -> M | None: return model +# endregion # region Operator CRUD @@ -482,15 +485,50 @@ def create_dag_project(db: Session, dag_project_create: DagProjectCreate) -> Dag return create_generic(db, DagProject(**dag_project_create.model_dump())) -def create_dag_node(db: Session, dag_node_create: DagNodeCreate) -> DagNode: +def create_dag_node(db: Session, dag_node_create: DagNodeBase) -> DagNode: return create_generic(db, DagNode(**dag_node_create.model_dump())) +def create_dag_edge(db: Session, dag_edge: DagNodeToDagNodeLink) -> DagNodeToDagNodeLink: + return create_generic(db, dag_edge) + + +def get_dag_projects( + db: Session, + skip: int = 0, + limit: int = 100, +) -> Sequence[DagProject]: + stmt = select(DagProject).offset(skip).limit(limit) + return db.scalars(stmt).all() + + def get_dag_project_by_name(db: Session, name: str) -> DagProject | None: stmt = select(DagProject).where(DagProject.name == name) return db.scalars(stmt).first() +def get_dag_node_by_name(db: Session, project_id: UUID, node_name: str) -> DagNode | None: + stmt = select(DagNode).where(DagNode.project_id == project_id).where(DagNode.name == node_name) + return db.scalars(stmt).first() + + +def get_dag_edge(db: Session, project_name: str, parent_name: str, child_name: str) -> DagNodeToDagNodeLink | None: + stmt = ( + select(DagNodeToDagNodeLink) + .where(DagNodeToDagNodeLink.project_name == project_name) + .where(DagNodeToDagNodeLink.parent_name == parent_name) + .where(DagNodeToDagNodeLink.child_name == child_name) + ) + return db.scalars(stmt).first() + + +def delete_dag_project_by_name(db: Session, name: str) -> DagProject | None: + return delete_generic( + db, + get_dag_project_by_name(db, name), + 
) + + # endregion # region Node CRUD diff --git a/src/concrete-db/concrete_db/orm/models.py b/src/concrete-db/concrete_db/orm/models.py index aa3bbfc2..f677bbe8 100644 --- a/src/concrete-db/concrete_db/orm/models.py +++ b/src/concrete-db/concrete_db/orm/models.py @@ -46,7 +46,7 @@ class ProfilePictureMixin(SQLModel): ) # TODO: probably use urllib here, oos -# Relationship Models +# region Link Models # TODO for all Link models: Drop index on id, replace with semantic primary key @@ -66,13 +66,17 @@ class UserToolLink(Base, table=True): class DagNodeToDagNodeLink(Base, table=True): - parent: UUID = Field(foreign_key="dagnode.id", primary_key=True, index=True, ondelete="CASCADE") - child: UUID = Field(foreign_key="dagnode.id", primary_key=True, ondelete="CASCADE") - kwarg_name: str = Field(description="Name of the argument to the child task") + project_name: str = Field(foreign_key="dagproject.name", primary_key=True, index=True, ondelete="CASCADE") + parent_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") + child_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") + input_to_child: str = Field(description="Name of the argument to the child task") + + project: "DagProject" = Relationship(back_populates="edges") # TODO maybe store transformation function +# endregion # region User Models @@ -346,7 +350,7 @@ class DagProjectBase(Base): ) -class DagProjectUpdate(DagProjectBase): +class DagProjectUpdate(Base): name: str | None = Field(description="Name of the project.", max_length=64, default=None) @@ -360,21 +364,35 @@ class DagProject(DagProjectBase, MetadataMixin, table=True): class DagNodeBase(Base): + project_id: UUID = Field( + description="ID of DAG Project this DAG Node belongs to.", + foreign_key="dagproject.id", + ondelete="CASCADE", + ) name: str = Field( description="Name of the DAG Project Node.", - unique=True, max_length=64, ) - task: str = Field(description="Name of method on Operator (e.g. 'chat')") - operator_id: UUID = Field( - description="ID of Operator encapsulated by this DAG Node.", - foreign_key="operator.id", - ondelete="CASCADE", + + # TODO: incorporate orchestrator OR decouple + operator_name: str = Field( + description="Name of Operator encapsulated by this DAG Node.", + max_length=64, + default="Operator", + ) + task_name: str = Field( + description="Name of method on Operator (e.g. 'chat')", + max_length=64, + default="chat", ) + default_task_kwargs: str = Field( description="Default kwargs for the task as JSON.", default="{}", ) + + __table_args__ = (UniqueConstraint("name", "project_id", name="no_duplicate_names_per_project"),) + # TODO: options @@ -382,12 +400,35 @@ class DagNodeUpdate(Base): pass -class DagNodeCreate(DagNodeBase): - pass +class DagNodeCreate(Base): + project_name: str = Field( + description="Name of the DAG Project this DAG Node belongs to.", + foreign_key="dagproject.name", + max_length=64, + ) + name: str = Field( + description="Name of the DAG Project Node.", + max_length=64, + ) + + operator_name: str = Field( + description="Name of Operator encapsulated by this DAG Node.", + max_length=64, + default="Operator", + ) + task_name: str = Field( + description="Name of method on Operator (e.g. 
'chat')", + max_length=64, + default="chat", + ) + default_task_kwargs: str = Field( + description="Default kwargs for the task as JSON.", + default="{}", + ) class DagNode(DagNodeBase, MetadataMixin, table=True): - operator: Operator = Relationship() + project: DagProject = Relationship(back_populates="nodes") # endregion diff --git a/webapp/api/server.py b/webapp/api/server.py index e51b60b3..06df24eb 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -1,6 +1,7 @@ +import json import os from collections.abc import Callable, Sequence -from typing import Annotated, Any +from typing import Annotated from uuid import UUID import dotenv @@ -12,6 +13,10 @@ Client, ClientCreate, ClientUpdate, + DagNodeBase, + DagNodeCreate, + DagNodeToDagNodeLink, + DagProject, DagProjectCreate, Operator, OperatorCreate, @@ -85,7 +90,9 @@ def ping(): return {"message": "pong"} -# ===CRUD operations for Orchestrators=== # +# region Orchestrators API + + @app.post("/orchestrators/", response_model=Orchestrator) def create_orchestrator(orchestrator: OrchestratorCreate) -> Orchestrator: with Session() as db: @@ -129,7 +136,10 @@ def delete_orchestrator(orchestrator_id: UUID) -> Orchestrator: return orchestrator -# ===CRUD operations for Operators=== # +# endregion +# region Operators API + + @app.post("/operators/") def create_operator(operator: OperatorCreate) -> Operator: with Session() as db: @@ -190,7 +200,10 @@ def delete_operator(orchestrator_id: UUID, operator_id: UUID) -> Operator: return operator -# ===CRUD operations for Clients=== # +# endregion +# region Clients API + + @app.post("/clients/") def create_client(client: ClientCreate) -> Client: with Session() as db: @@ -235,7 +248,7 @@ def read_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> Cl return client -@app.put("/orchestrator/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") +@app.put("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") def update_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID, client: ClientUpdate) -> Client: with Session() as db: db_client = crud.update_client(db, client_id, operator_id, orchestrator_id, client) @@ -244,7 +257,7 @@ def update_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID, cli return db_client -@app.delete("/orchestrator/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") +@app.delete("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> Client: with Session() as db: client = crud.delete_client(db, client_id, operator_id, orchestrator_id) @@ -253,12 +266,13 @@ def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> return client -# ===Project and Operator Building=== # -# TODO: add persistence +# endregion +# region DagProject API +# TODO: integrate better into persistence and Concept Hierarchy -@app.post("/build/project") -def initialize_project(dag_project_create: DagProjectCreate) -> str: +@app.post("/projects/dag/") +def initialize_project(project: DagProjectCreate) -> DagProject: """ Initiate a directed-acyclic-graph (DAG) project locally. Projects must be unique in name. @@ -271,65 +285,159 @@ def initialize_project(dag_project_create: DagProjectCreate) -> str: name: The name of the project to be initialized. 
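        A hedged sketch of the project lifecycle under the reorganized /projects/dag/ routes, including the read and delete endpoints added just below in this patch (base URL assumed).

            import requests

            BASE = "http://localhost:8000"  # assumed local dev server

            requests.post(f"{BASE}/projects/dag/", json={"name": "demo"})   # create (DagProjectCreate body)
            requests.get(f"{BASE}/projects/dag/")                           # list projects
            requests.get(f"{BASE}/projects/dag/demo")                       # fetch one by name
            requests.delete(f"{BASE}/projects/dag/demo")                    # delete by name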
""" - name = dag_project_create.name - with Session() as session: - db_project = crud.get_dag_project_by_name(session, name) + name = project.name + with Session() as db: + db_project = crud.get_dag_project_by_name(db, name) if db_project is not None: - raise HTTPException(status_code=400, detail="{name} already exists as a Project!") - db_project = crud.create_dag_project(session, dag_project_create) + raise HTTPException(status_code=400, detail=f"{name} already exists as a Project!") + db_project = crud.create_dag_project(db, project) + + return db_project + + +@app.get("/projects/dag/") +def read_projects(common_read_params: CommonReadDep) -> Sequence[DagProject]: + with Session() as db: + return crud.get_dag_projects( + db, + skip=common_read_params.skip, + limit=common_read_params.limit, + ) - return db_project +@app.get("/projects/dag/{project_name}") +def read_project(project_name: str) -> DagProject: + with Session() as db: + project = crud.get_dag_project_by_name(db, project_name) + if project is None: + raise project_not_found(project_name) + return project -@app.post("/build/project/{project}/task") -def expand_project_with_task( - project: str, operator: str, boost: str = "chat", task: str = "", default_task_kwargs: dict[str, Any] = {} -) -> str: + +@app.delete("/projects/dag/{project_name}") +def delete_project(project_name: str) -> DagProject: + with Session() as db: + project = crud.delete_dag_project_by_name(db, project_name) + if project is None: + raise project_not_found(project_name) + return project + + +@app.post("/projects/dag/{project_name}/tasks") +def expand_project_with_task(project_name: str, task: DagNodeCreate) -> DagProject: """ Expand a project by adding an operator task as a node in its DAG. - project: The name of the project to be expanded. + project_name: The name of the project to be expanded. + name: The name of the task instance this node represents. operator: The name of the operator whose task we'd like to use. - boost: The name of the operator's prompt boost to add as a node. - task: The name of the task this node represents. + task: The name of the operator's task to add as a node. default_task_kwargs: Any default arguments to pass to the task. """ - with Session() as session: - db_project = crud.get_dag_project_by_name(session) - if project not in PROJECTS: - raise project_not_found(project) - project_obj = PROJECTS[project] - node = DAGNode(task, boost, getattr(operators, operator)(), default_task_kwargs) - return project_obj.add_node(task, node).name + if project_name != task.project_name: + raise HTTPException( + status_code=400, + detail=f"Path project name {project_name} and body project name {task.project_name} don't match!", + ) + + with Session() as db: + project = crud.get_dag_project_by_name(db, task.project_name) + if project is None: + raise project_not_found(task.project_name) + node = crud.get_dag_node_by_name(db, project.id, task.name) + if node is not None: + raise HTTPException( + status_code=400, detail=f"{task.name} already exists as a node for {task.project_name}!" 
+ ) + + crud.create_dag_node( + db, + DagNodeBase( + project_id=project.id, + **task.model_dump(exclude=set("project")), + ), + ) -@app.post("/build/project/{project}/edge") -def expand_project_with_connection(project: str, parent: str, child: str, input_to_child: str) -> tuple[str, str, str]: + db.refresh(project) + return project + + +@app.post("/projects/dag/{project_name}/edges") +def expand_project_with_connection(project_name: str, edge: DagNodeToDagNodeLink) -> DagProject: """ Expand a project by connecting two tasks together. The output from the parent task will be fed into the child task. - project: The name of the project to be expanded. - parent: The name of the parent task in the connection. - child: The name of the child task in the connection. + project_name: The name of the project to be expanded. + parent_name: The name of the parent task in the connection. + child_name: The name of the child task in the connection. input_to_child: The name of the input to the child (equivalently, the output from the parent) """ - if project not in PROJECTS: - raise project_not_found(project) - project_obj = PROJECTS[project] - return project_obj.add_edge(parent, child, input_to_child) + if project_name != edge.project_name: + raise HTTPException( + status_code=400, + detail=f"Path project name {project_name} and body project name {edge.project_name} don't match!", + ) + + with Session() as db: + project = crud.get_dag_project_by_name(db, edge.project_name) + if project is None: + raise project_not_found(edge.project_name) + db_edge = crud.get_dag_edge(db, edge.project_name, edge.parent_name, edge.child_name) + if db_edge is not None: + raise HTTPException( + status_code=400, + detail=f"{edge.project_name} already has an edge from {edge.parent_name} to {edge.child_name}!", + ) -@app.post("/build/project/{project}/run") -async def run_project(project: str): + crud.create_dag_edge(db, edge) + + db.refresh(project) + return project + + +@app.post("/projects/dag/{project_name}/run") +async def run_project(project_name: str) -> list[tuple[str, str]]: """ Run a project from its sources to its sinks. project: The name of the project to be run. 
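        Putting the endpoints in this patch together, a client could build and execute a two-node DAG roughly as follows. The base URL is an assumption; operator_name "Operator" and task_name "chat" are simply the DagNodeCreate defaults, and the edge's input_to_child names the chat kwarg that receives the parent's output.

            import requests

            BASE = "http://localhost:8000"  # assumed local dev server

            # 1. Create the project row.
            requests.post(f"{BASE}/projects/dag/", json={"name": "demo"})

            # 2. Register two task nodes; unspecified fields fall back to the DagNodeCreate defaults.
            for node_name in ("ideas", "draft"):
                requests.post(
                    f"{BASE}/projects/dag/demo/tasks",
                    json={"project_name": "demo", "name": node_name},
                )

            # 3. Feed the output of `ideas` into `draft` as its `message` argument.
            requests.post(
                f"{BASE}/projects/dag/demo/edges",
                json={
                    "project_name": "demo",
                    "parent_name": "ideas",
                    "child_name": "draft",
                    "input_to_child": "message",
                },
            )

            # 4. Execute the DAG; the endpoint returns (node name, response text) pairs in execution order.
            for node_name, text in requests.post(f"{BASE}/projects/dag/demo/run").json():
                print(node_name, text)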
""" - if project not in PROJECTS: - raise project_not_found(project) - project_obj = PROJECTS[project] - async for operator, response in project_obj.execute(): + # TODO: error handling for cycles + with Session() as session: + db_project = crud.get_dag_project_by_name(session, project_name) + if db_project is None: + raise project_not_found(project_name) + nodes = db_project.nodes + edges = db_project.edges + + project = Project() + for node in nodes: + project.add_node( + node.name, + DAGNode( + node.name, + node.task_name, + getattr(operators, node.operator_name)(), + json.loads(node.default_task_kwargs), + ), + ) + for edge in edges: + project.add_edge( + edge.parent_name, + edge.child_name, + edge.input_to_child, + ) + + result = [] + async for operator, response in project.execute(): print(operator) - print(response) + print(response.text) + result.append((operator, response.text)) + + return result + + +# endregion From 6e91398e53b6a3cf718ffe5865eabdeb59462b5c Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 13 Nov 2024 15:50:01 -0800 Subject: [PATCH 04/14] testless, first iteration of DAGProject API --- src/concrete-async/concrete_async/__init__.py | 2 +- src/concrete-async/concrete_async/celery.py | 1 + src/concrete-core/concrete/__init__.py | 4 ++-- tests/test_tools.py | 3 ++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/concrete-async/concrete_async/__init__.py b/src/concrete-async/concrete_async/__init__.py index 7c5d380e..7ca05b50 100644 --- a/src/concrete-async/concrete_async/__init__.py +++ b/src/concrete-async/concrete_async/__init__.py @@ -3,10 +3,10 @@ from celery.result import AsyncResult from concrete.clients import CLIClient, model_to_schema -from concrete.models import KombuMixin, Message, Operation from concrete_async.tasks import abstract_operation import concrete +from concrete.models import KombuMixin, Message, Operation def _delay_factory(string_func: Callable[..., str]) -> Callable[..., AsyncResult]: diff --git a/src/concrete-async/concrete_async/celery.py b/src/concrete-async/concrete_async/celery.py index b191ab4e..a26f9586 100644 --- a/src/concrete-async/concrete_async/celery.py +++ b/src/concrete-async/concrete_async/celery.py @@ -1,4 +1,5 @@ from celery import Celery + from concrete.models import clients # noqa from . import celeryconfig diff --git a/src/concrete-core/concrete/__init__.py b/src/concrete-core/concrete/__init__.py index df0103e6..30b54778 100644 --- a/src/concrete-core/concrete/__init__.py +++ b/src/concrete-core/concrete/__init__.py @@ -1,8 +1,8 @@ from dotenv import load_dotenv -from . import operators, orchestrators +from . 
import abstract, models, operators, orchestrators # Always runs even when importing submodules # https://stackoverflow.com/a/27144933 load_dotenv(override=True) -__all__ = ["operators", "orchestrators"] +__all__ = ["abstract", "models", "operators", "orchestrators"] diff --git a/tests/test_tools.py b/tests/test_tools.py index 74212985..7e244500 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -2,11 +2,12 @@ from unittest.mock import Mock, patch import pytest -from concrete.models import messages from concrete.tools.http import HTTPTool from concrete.tools.utils import invoke_tool from requests.exceptions import HTTPError +from concrete.models import messages + def test_http_tool_process_response_ok(): mock_response = Mock() From fc50f4d4cac6f764a31b6ac0fd55abda3dd4e388 Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Fri, 15 Nov 2024 01:43:54 -0800 Subject: [PATCH 05/14] used JSON columns for default kwargs and options --- pyproject.toml | 1 + src/concrete-db/concrete_db/orm/models.py | 26 ++- uv.lock | 241 ++++++++++++++++++++++ webapp/api/server.py | 5 +- 4 files changed, 262 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0b2aebfb..d9323058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dev-dependencies = [ "concrete-core", "concrete-async", "concrete-db", + "fastapi[standard]>=0.115.4", ] [tool.uv.workspace] diff --git a/src/concrete-db/concrete_db/orm/models.py b/src/concrete-db/concrete_db/orm/models.py index f677bbe8..57e0cd54 100644 --- a/src/concrete-db/concrete_db/orm/models.py +++ b/src/concrete-db/concrete_db/orm/models.py @@ -10,10 +10,10 @@ from concrete.tools import MetaTool from concrete.tools.utils import tool_name_to_class from pydantic import ConfigDict, ValidationError, model_validator -from sqlalchemy import CheckConstraint, DateTime, UniqueConstraint +from sqlalchemy import CheckConstraint, Column, DateTime, UniqueConstraint from sqlalchemy.schema import Index from sqlalchemy.sql import func -from sqlmodel import Field, Relationship, SQLModel +from sqlmodel import JSON, Field, Relationship, SQLModel from .setup import SQLALCHEMY_DATABASE_URL, engine @@ -69,7 +69,7 @@ class DagNodeToDagNodeLink(Base, table=True): project_name: str = Field(foreign_key="dagproject.name", primary_key=True, index=True, ondelete="CASCADE") parent_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") child_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") - input_to_child: str = Field(description="Name of the argument to the child task") + input_to_child: str = Field(description="Name of the argument to the child task", default="message") project: "DagProject" = Relationship(back_populates="edges") @@ -378,17 +378,19 @@ class DagNodeBase(Base): operator_name: str = Field( description="Name of Operator encapsulated by this DAG Node.", max_length=64, - default="Operator", ) task_name: str = Field( description="Name of method on Operator (e.g. 'chat')", max_length=64, - default="chat", ) - default_task_kwargs: str = Field( + default_task_kwargs: dict = Field( description="Default kwargs for the task as JSON.", - default="{}", + sa_column=Column(JSON), + ) + options: dict = Field( + description="Options to run the task with. 
Includes tools, response format, etc.", + sa_column=Column(JSON), ) __table_args__ = (UniqueConstraint("name", "project_id", name="no_duplicate_names_per_project"),) @@ -421,9 +423,15 @@ class DagNodeCreate(Base): max_length=64, default="chat", ) - default_task_kwargs: str = Field( + default_task_kwargs: dict = Field( description="Default kwargs for the task as JSON.", - default="{}", + sa_column=Column(JSON), + default={"message": "Hi, how are you?"}, + ) + options: dict = Field( + description="Options to run the task with. Includes tools, response format, etc.", + sa_column=Column(JSON), + default_factory=dict, ) diff --git a/uv.lock b/uv.lock index 316dc89e..9048f57d 100644 --- a/uv.lock +++ b/uv.lock @@ -22,6 +22,7 @@ requirements = [ { name = "concrete-async" }, { name = "concrete-core" }, { name = "concrete-db" }, + { name = "fastapi", extras = ["standard"], specifier = ">=0.115.4" }, { name = "flake8", specifier = "==7.1.0" }, { name = "ipykernel", specifier = "==6.29.5" }, { name = "isort", specifier = "==5.13.2" }, @@ -703,6 +704,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, ] +[[package]] +name = "dnspython" +version = "2.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/4a/263763cb2ba3816dd94b08ad3a33d5fdae34ecb856678773cc40a3605829/dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1", size = 345197 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, +] + [[package]] name = "docs" version = "0.1.0" @@ -718,6 +728,19 @@ requires-dist = [ { name = "pymdown-extensions", specifier = ">=10.12" }, ] +[[package]] +name = "email-validator" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dnspython" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/48/ce/13508a1ec3f8bb981ae4ca79ea40384becc868bfae97fd1c942bb3a001b1/email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7", size = 48967 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/ee/bf0adb559ad3c786f12bcbc9296b3f5675f529199bef03e2df281fa1fadb/email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631", size = 33521 }, +] + [[package]] name = "executing" version = "2.1.0" @@ -741,6 +764,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/f6/af0d1f58f86002be0cf1e2665cdd6f7a4a71cdc8a7a9438cdc9e3b5375fe/fastapi-0.115.4-py3-none-any.whl", hash = "sha256:0b504a063ffb3cf96a5e27dc1bc32c80ca743a2528574f9cdc77daa2d31b4742", size = 94732 }, ] +[package.optional-dependencies] +standard = [ + { name = "email-validator" }, + { name = "fastapi-cli", extra = ["standard"] }, + { name = "httpx" }, + { name = "jinja2" }, + { name = "python-multipart" }, + { name = "uvicorn", extra = ["standard"] }, +] + +[[package]] +name = "fastapi-cli" +version = "0.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typer" }, + { name = "uvicorn", extra = 
["standard"] }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/f8/1ad5ce32d029aeb9117e9a5a9b3e314a8477525d60c12a9b7730a3c186ec/fastapi_cli-0.0.5.tar.gz", hash = "sha256:d30e1239c6f46fcb95e606f02cdda59a1e2fa778a54b64686b3ff27f6211ff9f", size = 15571 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/ea/4b5011012ac925fe2f83b19d0e09cee9d324141ec7bf5e78bb2817f96513/fastapi_cli-0.0.5-py3-none-any.whl", hash = "sha256:e94d847524648c748a5350673546bbf9bcaeb086b33c24f2e82e021436866a46", size = 9489 }, +] + +[package.optional-dependencies] +standard = [ + { name = "uvicorn", extra = ["standard"] }, +] + [[package]] name = "fastjsonschema" version = "2.20.0" @@ -966,6 +1017,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/6c/d2fbdaaa5959339d53ba38e94c123e4e84b8fbc4b84beb0e70d7c1608486/httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc", size = 96854 }, ] +[[package]] +name = "httptools" +version = "0.6.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/9a/ce5e1f7e131522e6d3426e8e7a490b3a01f39a6696602e1c4f33f9e94277/httptools-0.6.4.tar.gz", hash = "sha256:4e93eee4add6493b59a5c514da98c939b244fce4a0d8879cd3f466562f4b7d5c", size = 240639 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/26/bb526d4d14c2774fe07113ca1db7255737ffbb119315839af2065abfdac3/httptools-0.6.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f47f8ed67cc0ff862b84a1189831d1d33c963fb3ce1ee0c65d3b0cbe7b711069", size = 199029 }, + { url = "https://files.pythonhosted.org/packages/a6/17/3e0d3e9b901c732987a45f4f94d4e2c62b89a041d93db89eafb262afd8d5/httptools-0.6.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0614154d5454c21b6410fdf5262b4a3ddb0f53f1e1721cfd59d55f32138c578a", size = 103492 }, + { url = "https://files.pythonhosted.org/packages/b7/24/0fe235d7b69c42423c7698d086d4db96475f9b50b6ad26a718ef27a0bce6/httptools-0.6.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f8787367fbdfccae38e35abf7641dafc5310310a5987b689f4c32cc8cc3ee975", size = 462891 }, + { url = "https://files.pythonhosted.org/packages/b1/2f/205d1f2a190b72da6ffb5f41a3736c26d6fa7871101212b15e9b5cd8f61d/httptools-0.6.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40b0f7fe4fd38e6a507bdb751db0379df1e99120c65fbdc8ee6c1d044897a636", size = 459788 }, + { url = "https://files.pythonhosted.org/packages/6e/4c/d09ce0eff09057a206a74575ae8f1e1e2f0364d20e2442224f9e6612c8b9/httptools-0.6.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:40a5ec98d3f49904b9fe36827dcf1aadfef3b89e2bd05b0e35e94f97c2b14721", size = 433214 }, + { url = "https://files.pythonhosted.org/packages/3e/d2/84c9e23edbccc4a4c6f96a1b8d99dfd2350289e94f00e9ccc7aadde26fb5/httptools-0.6.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dacdd3d10ea1b4ca9df97a0a303cbacafc04b5cd375fa98732678151643d4988", size = 434120 }, + { url = "https://files.pythonhosted.org/packages/d0/46/4d8e7ba9581416de1c425b8264e2cadd201eb709ec1584c381f3e98f51c1/httptools-0.6.4-cp311-cp311-win_amd64.whl", hash = "sha256:288cd628406cc53f9a541cfaf06041b4c71d751856bab45e3702191f931ccd17", size = 88565 }, + { url = "https://files.pythonhosted.org/packages/bb/0e/d0b71465c66b9185f90a091ab36389a7352985fe857e352801c39d6127c8/httptools-0.6.4-cp312-cp312-macosx_10_13_universal2.whl", hash = 
"sha256:df017d6c780287d5c80601dafa31f17bddb170232d85c066604d8558683711a2", size = 200683 }, + { url = "https://files.pythonhosted.org/packages/e2/b8/412a9bb28d0a8988de3296e01efa0bd62068b33856cdda47fe1b5e890954/httptools-0.6.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:85071a1e8c2d051b507161f6c3e26155b5c790e4e28d7f236422dbacc2a9cc44", size = 104337 }, + { url = "https://files.pythonhosted.org/packages/9b/01/6fb20be3196ffdc8eeec4e653bc2a275eca7f36634c86302242c4fbb2760/httptools-0.6.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69422b7f458c5af875922cdb5bd586cc1f1033295aa9ff63ee196a87519ac8e1", size = 508796 }, + { url = "https://files.pythonhosted.org/packages/f7/d8/b644c44acc1368938317d76ac991c9bba1166311880bcc0ac297cb9d6bd7/httptools-0.6.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16e603a3bff50db08cd578d54f07032ca1631450ceb972c2f834c2b860c28ea2", size = 510837 }, + { url = "https://files.pythonhosted.org/packages/52/d8/254d16a31d543073a0e57f1c329ca7378d8924e7e292eda72d0064987486/httptools-0.6.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec4f178901fa1834d4a060320d2f3abc5c9e39766953d038f1458cb885f47e81", size = 485289 }, + { url = "https://files.pythonhosted.org/packages/5f/3c/4aee161b4b7a971660b8be71a92c24d6c64372c1ab3ae7f366b3680df20f/httptools-0.6.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb89ecf8b290f2e293325c646a211ff1c2493222798bb80a530c5e7502494f", size = 489779 }, + { url = "https://files.pythonhosted.org/packages/12/b7/5cae71a8868e555f3f67a50ee7f673ce36eac970f029c0c5e9d584352961/httptools-0.6.4-cp312-cp312-win_amd64.whl", hash = "sha256:db78cb9ca56b59b016e64b6031eda5653be0589dba2b1b43453f6e8b405a0970", size = 88634 }, + { url = "https://files.pythonhosted.org/packages/94/a3/9fe9ad23fd35f7de6b91eeb60848986058bd8b5a5c1e256f5860a160cc3e/httptools-0.6.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ade273d7e767d5fae13fa637f4d53b6e961fb7fd93c7797562663f0171c26660", size = 197214 }, + { url = "https://files.pythonhosted.org/packages/ea/d9/82d5e68bab783b632023f2fa31db20bebb4e89dfc4d2293945fd68484ee4/httptools-0.6.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:856f4bc0478ae143bad54a4242fccb1f3f86a6e1be5548fecfd4102061b3a083", size = 102431 }, + { url = "https://files.pythonhosted.org/packages/96/c1/cb499655cbdbfb57b577734fde02f6fa0bbc3fe9fb4d87b742b512908dff/httptools-0.6.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:322d20ea9cdd1fa98bd6a74b77e2ec5b818abdc3d36695ab402a0de8ef2865a3", size = 473121 }, + { url = "https://files.pythonhosted.org/packages/af/71/ee32fd358f8a3bb199b03261f10921716990808a675d8160b5383487a317/httptools-0.6.4-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d87b29bd4486c0093fc64dea80231f7c7f7eb4dc70ae394d70a495ab8436071", size = 473805 }, + { url = "https://files.pythonhosted.org/packages/8a/0a/0d4df132bfca1507114198b766f1737d57580c9ad1cf93c1ff673e3387be/httptools-0.6.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:342dd6946aa6bda4b8f18c734576106b8a31f2fe31492881a9a160ec84ff4bd5", size = 448858 }, + { url = "https://files.pythonhosted.org/packages/1e/6a/787004fdef2cabea27bad1073bf6a33f2437b4dbd3b6fb4a9d71172b1c7c/httptools-0.6.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b36913ba52008249223042dca46e69967985fb4051951f94357ea681e1f5dc0", size = 452042 }, + { url = 
"https://files.pythonhosted.org/packages/4d/dc/7decab5c404d1d2cdc1bb330b1bf70e83d6af0396fd4fc76fc60c0d522bf/httptools-0.6.4-cp313-cp313-win_amd64.whl", hash = "sha256:28908df1b9bb8187393d5b5db91435ccc9c8e891657f9cbb42a2541b44c82fc8", size = 87682 }, +] + [[package]] name = "httpx" version = "0.27.2" @@ -2152,6 +2232,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/35/a6/145655273568ee78a581e734cf35beb9e33a370b29c5d3c8fee3744de29f/python_json_logger-2.0.7-py3-none-any.whl", hash = "sha256:f380b826a991ebbe3de4d897aeec42760035ac760345e57b812938dc8b35e2bd", size = 8067 }, ] +[[package]] +name = "python-multipart" +version = "0.0.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/40/22/edea41c2d4a22e666c0c7db7acdcbf7bc8c1c1f7d3b3ca246ec982fec612/python_multipart-0.0.17.tar.gz", hash = "sha256:41330d831cae6e2f22902704ead2826ea038d0419530eadff3ea80175aec5538", size = 36452 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/fb/275137a799169392f1fa88fff2be92f16eee38e982720a8aaadefc4a36b2/python_multipart-0.0.17-py3-none-any.whl", hash = "sha256:15dc4f487e0a9476cc1201261188ee0940165cffc94429b6fc565c4d3045cb5d", size = 24453 }, +] + [[package]] name = "pywin32" version = "308" @@ -2446,6 +2535,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/12/282ee9bce8b58130cb762fbc9beabd531549952cac11fc56add11dcb7ea0/setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd", size = 1251070 }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, +] + [[package]] name = "six" version = "1.16.0" @@ -2626,6 +2724,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, ] +[[package]] +name = "typer" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/87/9eb07fdfa14e22ec7658b5b1147836d22df3848a22c85a4e18ed272303a5/typer-0.13.0.tar.gz", hash = "sha256:f1c7198347939361eec90139ffa0fd8b3df3a2259d5852a0f7400e476d95985c", size = 97572 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/7e/c8bfa8cbcd3ea1d25d2beb359b5c5a3f4339a7e2e5d9e3ef3e29ba3ab3b9/typer-0.13.0-py3-none-any.whl", hash = "sha256:d85fe0b777b2517cc99c8055ed735452f2659cd45e451507c76f48ce5c1d00e2", size = 44194 }, +] + [[package]] name = "types-awscrt" version = "0.23.0" @@ -2711,6 +2824,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/14/78bd0e95dd2444b6caacbca2b730671d4295ccb628ef58b81bee903629df/uvicorn-0.32.0-py3-none-any.whl", hash = 
"sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82", size = 63723 }, ] +[package.optional-dependencies] +standard = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "httptools" }, + { name = "python-dotenv" }, + { name = "pyyaml" }, + { name = "uvloop", marker = "platform_python_implementation != 'PyPy' and sys_platform != 'cygwin' and sys_platform != 'win32'" }, + { name = "watchfiles" }, + { name = "websockets" }, +] + +[[package]] +name = "uvloop" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/c0/854216d09d33c543f12a44b393c402e89a920b1a0a7dc634c42de91b9cf6/uvloop-0.21.0.tar.gz", hash = "sha256:3bf12b0fda68447806a7ad847bfa591613177275d35b6724b1ee573faa3704e3", size = 2492741 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/a7/4cf0334105c1160dd6819f3297f8700fda7fc30ab4f61fbf3e725acbc7cc/uvloop-0.21.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c0f3fa6200b3108919f8bdabb9a7f87f20e7097ea3c543754cabc7d717d95cf8", size = 1447410 }, + { url = "https://files.pythonhosted.org/packages/8c/7c/1517b0bbc2dbe784b563d6ab54f2ef88c890fdad77232c98ed490aa07132/uvloop-0.21.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0878c2640cf341b269b7e128b1a5fed890adc4455513ca710d77d5e93aa6d6a0", size = 805476 }, + { url = "https://files.pythonhosted.org/packages/ee/ea/0bfae1aceb82a503f358d8d2fa126ca9dbdb2ba9c7866974faec1cb5875c/uvloop-0.21.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9fb766bb57b7388745d8bcc53a359b116b8a04c83a2288069809d2b3466c37e", size = 3960855 }, + { url = "https://files.pythonhosted.org/packages/8a/ca/0864176a649838b838f36d44bf31c451597ab363b60dc9e09c9630619d41/uvloop-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8a375441696e2eda1c43c44ccb66e04d61ceeffcd76e4929e527b7fa401b90fb", size = 3973185 }, + { url = "https://files.pythonhosted.org/packages/30/bf/08ad29979a936d63787ba47a540de2132169f140d54aa25bc8c3df3e67f4/uvloop-0.21.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:baa0e6291d91649c6ba4ed4b2f982f9fa165b5bbd50a9e203c416a2797bab3c6", size = 3820256 }, + { url = "https://files.pythonhosted.org/packages/da/e2/5cf6ef37e3daf2f06e651aae5ea108ad30df3cb269102678b61ebf1fdf42/uvloop-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4509360fcc4c3bd2c70d87573ad472de40c13387f5fda8cb58350a1d7475e58d", size = 3937323 }, + { url = "https://files.pythonhosted.org/packages/8c/4c/03f93178830dc7ce8b4cdee1d36770d2f5ebb6f3d37d354e061eefc73545/uvloop-0.21.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:359ec2c888397b9e592a889c4d72ba3d6befba8b2bb01743f72fffbde663b59c", size = 1471284 }, + { url = "https://files.pythonhosted.org/packages/43/3e/92c03f4d05e50f09251bd8b2b2b584a2a7f8fe600008bcc4523337abe676/uvloop-0.21.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f7089d2dc73179ce5ac255bdf37c236a9f914b264825fdaacaded6990a7fb4c2", size = 821349 }, + { url = "https://files.pythonhosted.org/packages/a6/ef/a02ec5da49909dbbfb1fd205a9a1ac4e88ea92dcae885e7c961847cd51e2/uvloop-0.21.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa4dcdbd9ae0a372f2167a207cd98c9f9a1ea1188a8a526431eef2f8116cc8d", size = 4580089 }, + { url = "https://files.pythonhosted.org/packages/06/a7/b4e6a19925c900be9f98bec0a75e6e8f79bb53bdeb891916609ab3958967/uvloop-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:86975dca1c773a2c9864f4c52c5a55631038e387b47eaf56210f873887b6c8dc", size = 4693770 }, + { url = "https://files.pythonhosted.org/packages/ce/0c/f07435a18a4b94ce6bd0677d8319cd3de61f3a9eeb1e5f8ab4e8b5edfcb3/uvloop-0.21.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:461d9ae6660fbbafedd07559c6a2e57cd553b34b0065b6550685f6653a98c1cb", size = 4451321 }, + { url = "https://files.pythonhosted.org/packages/8f/eb/f7032be105877bcf924709c97b1bf3b90255b4ec251f9340cef912559f28/uvloop-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:183aef7c8730e54c9a3ee3227464daed66e37ba13040bb3f350bc2ddc040f22f", size = 4659022 }, + { url = "https://files.pythonhosted.org/packages/3f/8d/2cbef610ca21539f0f36e2b34da49302029e7c9f09acef0b1c3b5839412b/uvloop-0.21.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:bfd55dfcc2a512316e65f16e503e9e450cab148ef11df4e4e679b5e8253a5281", size = 1468123 }, + { url = "https://files.pythonhosted.org/packages/93/0d/b0038d5a469f94ed8f2b2fce2434a18396d8fbfb5da85a0a9781ebbdec14/uvloop-0.21.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:787ae31ad8a2856fc4e7c095341cccc7209bd657d0e71ad0dc2ea83c4a6fa8af", size = 819325 }, + { url = "https://files.pythonhosted.org/packages/50/94/0a687f39e78c4c1e02e3272c6b2ccdb4e0085fda3b8352fecd0410ccf915/uvloop-0.21.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ee4d4ef48036ff6e5cfffb09dd192c7a5027153948d85b8da7ff705065bacc6", size = 4582806 }, + { url = "https://files.pythonhosted.org/packages/d2/19/f5b78616566ea68edd42aacaf645adbf71fbd83fc52281fba555dc27e3f1/uvloop-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3df876acd7ec037a3d005b3ab85a7e4110422e4d9c1571d4fc89b0fc41b6816", size = 4701068 }, + { url = "https://files.pythonhosted.org/packages/47/57/66f061ee118f413cd22a656de622925097170b9380b30091b78ea0c6ea75/uvloop-0.21.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd53ecc9a0f3d87ab847503c2e1552b690362e005ab54e8a48ba97da3924c0dc", size = 4454428 }, + { url = "https://files.pythonhosted.org/packages/63/9a/0962b05b308494e3202d3f794a6e85abe471fe3cafdbcf95c2e8c713aabd/uvloop-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a5c39f217ab3c663dc699c04cbd50c13813e31d917642d459fdcec07555cc553", size = 4660018 }, +] + [[package]] name = "vine" version = "5.1.0" @@ -2761,6 +2911,55 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067 }, ] +[[package]] +name = "watchfiles" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c8/27/2ba23c8cc85796e2d41976439b08d52f691655fdb9401362099502d1f0cf/watchfiles-0.24.0.tar.gz", hash = "sha256:afb72325b74fa7a428c009c1b8be4b4d7c2afedafb2982827ef2156646df2fe1", size = 37870 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/02/366ae902cd81ca5befcd1854b5c7477b378f68861597cef854bd6dc69fbe/watchfiles-0.24.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:bdcd5538e27f188dd3c804b4a8d5f52a7fc7f87e7fd6b374b8e36a4ca03db428", size = 375579 }, + { url = "https://files.pythonhosted.org/packages/bc/67/d8c9d256791fe312fea118a8a051411337c948101a24586e2df237507976/watchfiles-0.24.0-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:2dadf8a8014fde6addfd3c379e6ed1a981c8f0a48292d662e27cabfe4239c83c", size = 367726 }, + { url = "https://files.pythonhosted.org/packages/b1/dc/a8427b21ef46386adf824a9fec4be9d16a475b850616cfd98cf09a97a2ef/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6509ed3f467b79d95fc62a98229f79b1a60d1b93f101e1c61d10c95a46a84f43", size = 437735 }, + { url = "https://files.pythonhosted.org/packages/3a/21/0b20bef581a9fbfef290a822c8be645432ceb05fb0741bf3c032e0d90d9a/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8360f7314a070c30e4c976b183d1d8d1585a4a50c5cb603f431cebcbb4f66327", size = 433644 }, + { url = "https://files.pythonhosted.org/packages/1c/e8/d5e5f71cc443c85a72e70b24269a30e529227986096abe091040d6358ea9/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:316449aefacf40147a9efaf3bd7c9bdd35aaba9ac5d708bd1eb5763c9a02bef5", size = 450928 }, + { url = "https://files.pythonhosted.org/packages/61/ee/bf17f5a370c2fcff49e1fec987a6a43fd798d8427ea754ce45b38f9e117a/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73bde715f940bea845a95247ea3e5eb17769ba1010efdc938ffcb967c634fa61", size = 469072 }, + { url = "https://files.pythonhosted.org/packages/a3/34/03b66d425986de3fc6077e74a74c78da298f8cb598887f664a4485e55543/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3770e260b18e7f4e576edca4c0a639f704088602e0bc921c5c2e721e3acb8d15", size = 475517 }, + { url = "https://files.pythonhosted.org/packages/70/eb/82f089c4f44b3171ad87a1b433abb4696f18eb67292909630d886e073abe/watchfiles-0.24.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0fd7248cf533c259e59dc593a60973a73e881162b1a2f73360547132742823", size = 425480 }, + { url = "https://files.pythonhosted.org/packages/53/20/20509c8f5291e14e8a13104b1808cd7cf5c44acd5feaecb427a49d387774/watchfiles-0.24.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d7a2e3b7f5703ffbd500dabdefcbc9eafeff4b9444bbdd5d83d79eedf8428fab", size = 612322 }, + { url = "https://files.pythonhosted.org/packages/df/2b/5f65014a8cecc0a120f5587722068a975a692cadbe9fe4ea56b3d8e43f14/watchfiles-0.24.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d831ee0a50946d24a53821819b2327d5751b0c938b12c0653ea5be7dea9c82ec", size = 595094 }, + { url = "https://files.pythonhosted.org/packages/18/98/006d8043a82c0a09d282d669c88e587b3a05cabdd7f4900e402250a249ac/watchfiles-0.24.0-cp311-none-win32.whl", hash = "sha256:49d617df841a63b4445790a254013aea2120357ccacbed00253f9c2b5dc24e2d", size = 264191 }, + { url = "https://files.pythonhosted.org/packages/8a/8b/badd9247d6ec25f5f634a9b3d0d92e39c045824ec7e8afcedca8ee52c1e2/watchfiles-0.24.0-cp311-none-win_amd64.whl", hash = "sha256:d3dcb774e3568477275cc76554b5a565024b8ba3a0322f77c246bc7111c5bb9c", size = 277527 }, + { url = "https://files.pythonhosted.org/packages/af/19/35c957c84ee69d904299a38bae3614f7cede45f07f174f6d5a2f4dbd6033/watchfiles-0.24.0-cp311-none-win_arm64.whl", hash = "sha256:9301c689051a4857d5b10777da23fafb8e8e921bcf3abe6448a058d27fb67633", size = 266253 }, + { url = "https://files.pythonhosted.org/packages/35/82/92a7bb6dc82d183e304a5f84ae5437b59ee72d48cee805a9adda2488b237/watchfiles-0.24.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:7211b463695d1e995ca3feb38b69227e46dbd03947172585ecb0588f19b0d87a", size = 374137 }, + { url = 
"https://files.pythonhosted.org/packages/87/91/49e9a497ddaf4da5e3802d51ed67ff33024597c28f652b8ab1e7c0f5718b/watchfiles-0.24.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b8693502d1967b00f2fb82fc1e744df128ba22f530e15b763c8d82baee15370", size = 367733 }, + { url = "https://files.pythonhosted.org/packages/0d/d8/90eb950ab4998effea2df4cf3a705dc594f6bc501c5a353073aa990be965/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdab9555053399318b953a1fe1f586e945bc8d635ce9d05e617fd9fe3a4687d6", size = 437322 }, + { url = "https://files.pythonhosted.org/packages/6c/a2/300b22e7bc2a222dd91fce121cefa7b49aa0d26a627b2777e7bdfcf1110b/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34e19e56d68b0dad5cff62273107cf5d9fbaf9d75c46277aa5d803b3ef8a9e9b", size = 433409 }, + { url = "https://files.pythonhosted.org/packages/99/44/27d7708a43538ed6c26708bcccdde757da8b7efb93f4871d4cc39cffa1cc/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:41face41f036fee09eba33a5b53a73e9a43d5cb2c53dad8e61fa6c9f91b5a51e", size = 452142 }, + { url = "https://files.pythonhosted.org/packages/b0/ec/c4e04f755be003129a2c5f3520d2c47026f00da5ecb9ef1e4f9449637571/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5148c2f1ea043db13ce9b0c28456e18ecc8f14f41325aa624314095b6aa2e9ea", size = 469414 }, + { url = "https://files.pythonhosted.org/packages/c5/4e/cdd7de3e7ac6432b0abf282ec4c1a1a2ec62dfe423cf269b86861667752d/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e4bd963a935aaf40b625c2499f3f4f6bbd0c3776f6d3bc7c853d04824ff1c9f", size = 472962 }, + { url = "https://files.pythonhosted.org/packages/27/69/e1da9d34da7fc59db358424f5d89a56aaafe09f6961b64e36457a80a7194/watchfiles-0.24.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c79d7719d027b7a42817c5d96461a99b6a49979c143839fc37aa5748c322f234", size = 425705 }, + { url = "https://files.pythonhosted.org/packages/e8/c1/24d0f7357be89be4a43e0a656259676ea3d7a074901f47022f32e2957798/watchfiles-0.24.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:32aa53a9a63b7f01ed32e316e354e81e9da0e6267435c7243bf8ae0f10b428ef", size = 612851 }, + { url = "https://files.pythonhosted.org/packages/c7/af/175ba9b268dec56f821639c9893b506c69fd999fe6a2e2c51de420eb2f01/watchfiles-0.24.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ce72dba6a20e39a0c628258b5c308779b8697f7676c254a845715e2a1039b968", size = 594868 }, + { url = "https://files.pythonhosted.org/packages/44/81/1f701323a9f70805bc81c74c990137123344a80ea23ab9504a99492907f8/watchfiles-0.24.0-cp312-none-win32.whl", hash = "sha256:d9018153cf57fc302a2a34cb7564870b859ed9a732d16b41a9b5cb2ebed2d444", size = 264109 }, + { url = "https://files.pythonhosted.org/packages/b4/0b/32cde5bc2ebd9f351be326837c61bdeb05ad652b793f25c91cac0b48a60b/watchfiles-0.24.0-cp312-none-win_amd64.whl", hash = "sha256:551ec3ee2a3ac9cbcf48a4ec76e42c2ef938a7e905a35b42a1267fa4b1645896", size = 277055 }, + { url = "https://files.pythonhosted.org/packages/4b/81/daade76ce33d21dbec7a15afd7479de8db786e5f7b7d249263b4ea174e08/watchfiles-0.24.0-cp312-none-win_arm64.whl", hash = "sha256:b52a65e4ea43c6d149c5f8ddb0bef8d4a1e779b77591a458a893eb416624a418", size = 266169 }, + { url = "https://files.pythonhosted.org/packages/30/dc/6e9f5447ae14f645532468a84323a942996d74d5e817837a5c8ce9d16c69/watchfiles-0.24.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = 
"sha256:3d2e3ab79a1771c530233cadfd277fcc762656d50836c77abb2e5e72b88e3a48", size = 373764 }, + { url = "https://files.pythonhosted.org/packages/79/c0/c3a9929c372816c7fc87d8149bd722608ea58dc0986d3ef7564c79ad7112/watchfiles-0.24.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:327763da824817b38ad125dcd97595f942d720d32d879f6c4ddf843e3da3fe90", size = 367873 }, + { url = "https://files.pythonhosted.org/packages/2e/11/ff9a4445a7cfc1c98caf99042df38964af12eed47d496dd5d0d90417349f/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd82010f8ab451dabe36054a1622870166a67cf3fce894f68895db6f74bbdc94", size = 438381 }, + { url = "https://files.pythonhosted.org/packages/48/a3/763ba18c98211d7bb6c0f417b2d7946d346cdc359d585cc28a17b48e964b/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d64ba08db72e5dfd5c33be1e1e687d5e4fcce09219e8aee893a4862034081d4e", size = 432809 }, + { url = "https://files.pythonhosted.org/packages/30/4c/616c111b9d40eea2547489abaf4ffc84511e86888a166d3a4522c2ba44b5/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1cf1f6dd7825053f3d98f6d33f6464ebdd9ee95acd74ba2c34e183086900a827", size = 451801 }, + { url = "https://files.pythonhosted.org/packages/b6/be/d7da83307863a422abbfeb12903a76e43200c90ebe5d6afd6a59d158edea/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43e3e37c15a8b6fe00c1bce2473cfa8eb3484bbeecf3aefbf259227e487a03df", size = 468886 }, + { url = "https://files.pythonhosted.org/packages/1d/d3/3dfe131ee59d5e90b932cf56aba5c996309d94dafe3d02d204364c23461c/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88bcd4d0fe1d8ff43675360a72def210ebad3f3f72cabfeac08d825d2639b4ab", size = 472973 }, + { url = "https://files.pythonhosted.org/packages/42/6c/279288cc5653a289290d183b60a6d80e05f439d5bfdfaf2d113738d0f932/watchfiles-0.24.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:999928c6434372fde16c8f27143d3e97201160b48a614071261701615a2a156f", size = 425282 }, + { url = "https://files.pythonhosted.org/packages/d6/d7/58afe5e85217e845edf26d8780c2d2d2ae77675eeb8d1b8b8121d799ce52/watchfiles-0.24.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:30bbd525c3262fd9f4b1865cb8d88e21161366561cd7c9e1194819e0a33ea86b", size = 612540 }, + { url = "https://files.pythonhosted.org/packages/6d/d5/b96eeb9fe3fda137200dd2f31553670cbc731b1e13164fd69b49870b76ec/watchfiles-0.24.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:edf71b01dec9f766fb285b73930f95f730bb0943500ba0566ae234b5c1618c18", size = 593625 }, + { url = "https://files.pythonhosted.org/packages/c1/e5/c326fe52ee0054107267608d8cea275e80be4455b6079491dfd9da29f46f/watchfiles-0.24.0-cp313-none-win32.whl", hash = "sha256:f4c96283fca3ee09fb044f02156d9570d156698bc3734252175a38f0e8975f07", size = 263899 }, + { url = "https://files.pythonhosted.org/packages/a6/8b/8a7755c5e7221bb35fe4af2dc44db9174f90ebf0344fd5e9b1e8b42d381e/watchfiles-0.24.0-cp313-none-win_amd64.whl", hash = "sha256:a974231b4fdd1bb7f62064a0565a6b107d27d21d9acb50c484d2cdba515b9366", size = 276622 }, +] + [[package]] name = "wcwidth" version = "0.2.13" @@ -2796,3 +2995,45 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf594 wheels = [ { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = 
"sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 }, ] + +[[package]] +name = "websockets" +version = "14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/1b/380b883ce05bb5f45a905b61790319a28958a9ab1e4b6b95ff5464b60ca1/websockets-14.1.tar.gz", hash = "sha256:398b10c77d471c0aab20a845e7a60076b6390bfdaac7a6d2edb0d2c59d75e8d8", size = 162840 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/ed/c0d03cb607b7fe1f7ff45e2cd4bb5cd0f9e3299ced79c2c303a6fff44524/websockets-14.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:449d77d636f8d9c17952628cc7e3b8faf6e92a17ec581ec0c0256300717e1512", size = 161949 }, + { url = "https://files.pythonhosted.org/packages/06/91/bf0a44e238660d37a2dda1b4896235d20c29a2d0450f3a46cd688f43b239/websockets-14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a35f704be14768cea9790d921c2c1cc4fc52700410b1c10948511039be824aac", size = 159606 }, + { url = "https://files.pythonhosted.org/packages/ff/b8/7185212adad274c2b42b6a24e1ee6b916b7809ed611cbebc33b227e5c215/websockets-14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b1f3628a0510bd58968c0f60447e7a692933589b791a6b572fcef374053ca280", size = 159854 }, + { url = "https://files.pythonhosted.org/packages/5a/8a/0849968d83474be89c183d8ae8dcb7f7ada1a3c24f4d2a0d7333c231a2c3/websockets-14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c3deac3748ec73ef24fc7be0b68220d14d47d6647d2f85b2771cb35ea847aa1", size = 169402 }, + { url = "https://files.pythonhosted.org/packages/bd/4f/ef886e37245ff6b4a736a09b8468dae05d5d5c99de1357f840d54c6f297d/websockets-14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7048eb4415d46368ef29d32133134c513f507fff7d953c18c91104738a68c3b3", size = 168406 }, + { url = "https://files.pythonhosted.org/packages/11/43/e2dbd4401a63e409cebddedc1b63b9834de42f51b3c84db885469e9bdcef/websockets-14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6cf0ad281c979306a6a34242b371e90e891bce504509fb6bb5246bbbf31e7b6", size = 168776 }, + { url = "https://files.pythonhosted.org/packages/6d/d6/7063e3f5c1b612e9f70faae20ebaeb2e684ffa36cb959eb0862ee2809b32/websockets-14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cc1fc87428c1d18b643479caa7b15db7d544652e5bf610513d4a3478dbe823d0", size = 169083 }, + { url = "https://files.pythonhosted.org/packages/49/69/e6f3d953f2fa0f8a723cf18cd011d52733bd7f6e045122b24e0e7f49f9b0/websockets-14.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:f95ba34d71e2fa0c5d225bde3b3bdb152e957150100e75c86bc7f3964c450d89", size = 168529 }, + { url = "https://files.pythonhosted.org/packages/70/ff/f31fa14561fc1d7b8663b0ed719996cf1f581abee32c8fb2f295a472f268/websockets-14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9481a6de29105d73cf4515f2bef8eb71e17ac184c19d0b9918a3701c6c9c4f23", size = 168475 }, + { url = "https://files.pythonhosted.org/packages/f1/15/b72be0e4bf32ff373aa5baef46a4c7521b8ea93ad8b49ca8c6e8e764c083/websockets-14.1-cp311-cp311-win32.whl", hash = "sha256:368a05465f49c5949e27afd6fbe0a77ce53082185bbb2ac096a3a8afaf4de52e", size = 162833 }, + { url = "https://files.pythonhosted.org/packages/bc/ef/2d81679acbe7057ffe2308d422f744497b52009ea8bab34b6d74a2657d1d/websockets-14.1-cp311-cp311-win_amd64.whl", hash = "sha256:6d24fc337fc055c9e83414c94e1ee0dee902a486d19d2a7f0929e49d7d604b09", size = 163263 
}, + { url = "https://files.pythonhosted.org/packages/55/64/55698544ce29e877c9188f1aee9093712411a8fc9732cca14985e49a8e9c/websockets-14.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ed907449fe5e021933e46a3e65d651f641975a768d0649fee59f10c2985529ed", size = 161957 }, + { url = "https://files.pythonhosted.org/packages/a2/b1/b088f67c2b365f2c86c7b48edb8848ac27e508caf910a9d9d831b2f343cb/websockets-14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:87e31011b5c14a33b29f17eb48932e63e1dcd3fa31d72209848652310d3d1f0d", size = 159620 }, + { url = "https://files.pythonhosted.org/packages/c1/89/2a09db1bbb40ba967a1b8225b07b7df89fea44f06de9365f17f684d0f7e6/websockets-14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bc6ccf7d54c02ae47a48ddf9414c54d48af9c01076a2e1023e3b486b6e72c707", size = 159852 }, + { url = "https://files.pythonhosted.org/packages/ca/c1/f983138cd56e7d3079f1966e81f77ce6643f230cd309f73aa156bb181749/websockets-14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9777564c0a72a1d457f0848977a1cbe15cfa75fa2f67ce267441e465717dcf1a", size = 169675 }, + { url = "https://files.pythonhosted.org/packages/c1/c8/84191455d8660e2a0bdb33878d4ee5dfa4a2cedbcdc88bbd097303b65bfa/websockets-14.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a655bde548ca98f55b43711b0ceefd2a88a71af6350b0c168aa77562104f3f45", size = 168619 }, + { url = "https://files.pythonhosted.org/packages/8d/a7/62e551fdcd7d44ea74a006dc193aba370505278ad76efd938664531ce9d6/websockets-14.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3dfff83ca578cada2d19e665e9c8368e1598d4e787422a460ec70e531dbdd58", size = 169042 }, + { url = "https://files.pythonhosted.org/packages/ad/ed/1532786f55922c1e9c4d329608e36a15fdab186def3ca9eb10d7465bc1cc/websockets-14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6a6c9bcf7cdc0fd41cc7b7944447982e8acfd9f0d560ea6d6845428ed0562058", size = 169345 }, + { url = "https://files.pythonhosted.org/packages/ea/fb/160f66960d495df3de63d9bcff78e1b42545b2a123cc611950ffe6468016/websockets-14.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4b6caec8576e760f2c7dd878ba817653144d5f369200b6ddf9771d64385b84d4", size = 168725 }, + { url = "https://files.pythonhosted.org/packages/cf/53/1bf0c06618b5ac35f1d7906444b9958f8485682ab0ea40dee7b17a32da1e/websockets-14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb6d38971c800ff02e4a6afd791bbe3b923a9a57ca9aeab7314c21c84bf9ff05", size = 168712 }, + { url = "https://files.pythonhosted.org/packages/e5/22/5ec2f39fff75f44aa626f86fa7f20594524a447d9c3be94d8482cd5572ef/websockets-14.1-cp312-cp312-win32.whl", hash = "sha256:1d045cbe1358d76b24d5e20e7b1878efe578d9897a25c24e6006eef788c0fdf0", size = 162838 }, + { url = "https://files.pythonhosted.org/packages/74/27/28f07df09f2983178db7bf6c9cccc847205d2b92ced986cd79565d68af4f/websockets-14.1-cp312-cp312-win_amd64.whl", hash = "sha256:90f4c7a069c733d95c308380aae314f2cb45bd8a904fb03eb36d1a4983a4993f", size = 163277 }, + { url = "https://files.pythonhosted.org/packages/34/77/812b3ba5110ed8726eddf9257ab55ce9e85d97d4aa016805fdbecc5e5d48/websockets-14.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3630b670d5057cd9e08b9c4dab6493670e8e762a24c2c94ef312783870736ab9", size = 161966 }, + { url = "https://files.pythonhosted.org/packages/8d/24/4fcb7aa6986ae7d9f6d083d9d53d580af1483c5ec24bdec0978307a0f6ac/websockets-14.1-cp313-cp313-macosx_10_13_x86_64.whl", 
hash = "sha256:36ebd71db3b89e1f7b1a5deaa341a654852c3518ea7a8ddfdf69cc66acc2db1b", size = 159625 }, + { url = "https://files.pythonhosted.org/packages/f8/47/2a0a3a2fc4965ff5b9ce9324d63220156bd8bedf7f90824ab92a822e65fd/websockets-14.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5b918d288958dc3fa1c5a0b9aa3256cb2b2b84c54407f4813c45d52267600cd3", size = 159857 }, + { url = "https://files.pythonhosted.org/packages/dd/c8/d7b425011a15e35e17757e4df75b25e1d0df64c0c315a44550454eaf88fc/websockets-14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00fe5da3f037041da1ee0cf8e308374e236883f9842c7c465aa65098b1c9af59", size = 169635 }, + { url = "https://files.pythonhosted.org/packages/93/39/6e3b5cffa11036c40bd2f13aba2e8e691ab2e01595532c46437b56575678/websockets-14.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8149a0f5a72ca36720981418eeffeb5c2729ea55fa179091c81a0910a114a5d2", size = 168578 }, + { url = "https://files.pythonhosted.org/packages/cf/03/8faa5c9576299b2adf34dcccf278fc6bbbcda8a3efcc4d817369026be421/websockets-14.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77569d19a13015e840b81550922056acabc25e3f52782625bc6843cfa034e1da", size = 169018 }, + { url = "https://files.pythonhosted.org/packages/8c/05/ea1fec05cc3a60defcdf0bb9f760c3c6bd2dd2710eff7ac7f891864a22ba/websockets-14.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cf5201a04550136ef870aa60ad3d29d2a59e452a7f96b94193bee6d73b8ad9a9", size = 169383 }, + { url = "https://files.pythonhosted.org/packages/21/1d/eac1d9ed787f80754e51228e78855f879ede1172c8b6185aca8cef494911/websockets-14.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:88cf9163ef674b5be5736a584c999e98daf3aabac6e536e43286eb74c126b9c7", size = 168773 }, + { url = "https://files.pythonhosted.org/packages/0e/1b/e808685530185915299740d82b3a4af3f2b44e56ccf4389397c7a5d95d39/websockets-14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:836bef7ae338a072e9d1863502026f01b14027250a4545672673057997d5c05a", size = 168757 }, + { url = "https://files.pythonhosted.org/packages/b6/19/6ab716d02a3b068fbbeb6face8a7423156e12c446975312f1c7c0f4badab/websockets-14.1-cp313-cp313-win32.whl", hash = "sha256:0d4290d559d68288da9f444089fd82490c8d2744309113fc26e2da6e48b65da6", size = 162834 }, + { url = "https://files.pythonhosted.org/packages/6c/fd/ab6b7676ba712f2fc89d1347a4b5bdc6aa130de10404071f2b2606450209/websockets-14.1-cp313-cp313-win_amd64.whl", hash = "sha256:8621a07991add373c3c5c2cf89e1d277e49dc82ed72c75e3afc74bd0acc446f0", size = 163277 }, + { url = "https://files.pythonhosted.org/packages/b0/0b/c7e5d11020242984d9d37990310520ed663b942333b83a033c2f20191113/websockets-14.1-py3-none-any.whl", hash = "sha256:4d4fc827a20abe6d544a119896f6b78ee13fe81cbfef416f3f2ddf09a03f0e2e", size = 156277 }, +] diff --git a/webapp/api/server.py b/webapp/api/server.py index 06df24eb..b15e02c1 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -1,4 +1,3 @@ -import json import os from collections.abc import Callable, Sequence from typing import Annotated @@ -333,6 +332,7 @@ def expand_project_with_task(project_name: str, task: DagNodeCreate) -> DagProje operator: The name of the operator whose task we'd like to use. task: The name of the operator's task to add as a node. default_task_kwargs: Any default arguments to pass to the task. + options: Any options to pass to the task, e.g. tools, response format. 
""" if project_name != task.project_name: raise HTTPException( @@ -421,7 +421,8 @@ async def run_project(project_name: str) -> list[tuple[str, str]]: node.name, node.task_name, getattr(operators, node.operator_name)(), - json.loads(node.default_task_kwargs), + node.default_task_kwargs, + node.options, ), ) for edge in edges: From a54cb5b7ab752a8df8eb11495c06b85d68e4e7a3 Mon Sep 17 00:00:00 2001 From: Dance Date: Mon, 18 Nov 2024 19:15:56 -0700 Subject: [PATCH 06/14] session refactor and bootstrap for webapp tests --- docs/developer-guide/database.md | 13 +- docs/sdk-reference/database.md | 9 +- src/concrete-core/concrete/tools/knowledge.py | 39 ++- src/concrete-db/concrete_db/__init__.py | 5 +- src/concrete-db/concrete_db/orm/__init__.py | 4 +- src/concrete-db/concrete_db/orm/setup.py | 11 - webapp/api/__init__.py | 0 webapp/api/server.py | 331 +++++++++--------- webapp/auth/server.py | 13 +- webapp/common.py | 8 + webapp/daemons/server.py | 5 +- webapp/main/server.py | 49 +-- webapp/tests/__init__.py | 0 webapp/tests/conftest.py | 36 ++ webapp/tests/test_api.py | 24 ++ 15 files changed, 297 insertions(+), 250 deletions(-) create mode 100644 webapp/api/__init__.py create mode 100644 webapp/tests/__init__.py create mode 100644 webapp/tests/conftest.py create mode 100644 webapp/tests/test_api.py diff --git a/docs/developer-guide/database.md b/docs/developer-guide/database.md index 6d031d16..b84c845f 100644 --- a/docs/developer-guide/database.md +++ b/docs/developer-guide/database.md @@ -22,25 +22,26 @@ class my_table(Base, table=True): ## DB Operations -Use `concrete.db.orm.Session` to get a session context manager. +Pass `concrete.db.orm.engine` to `sqlmodel.Session` to get a session context manager. Use this session to perform DB operations. Best practice is to use one session per one transaction. By default, sessions will not flush or commit. ```python -from concrete.db.orm import Session +from concrete.db.orm import engine +from sqlmodel import Session # The following solutions achieve the same thing, but with different approaches # ORM Centric solution def delete_my_table_orm(): - with Session() as session: + with Session(engine) as session: deleted_count = session.query(my_table).where(my_column == "my_value").delete() session.commit() return deleted_count def delete_my_table_core(): - with Session() as session: + with Session(engine) as session: stmt = delete(my_table).where(my_column == "my_value") result = session.execute(stmt) deleted_count = result.rowcount @@ -169,5 +170,5 @@ connection.execute( For a deeper dive, please examine the [alembic operations reference](https://alembic.sqlalchemy.org/en/latest/ops.html). -Last Updated: 2024-11-07 20:42:06 UTC -Lines Changed: +173, -0 +Last Updated: 2024-11-19 02:16:01 UTC +Lines Changed: +7, -6 diff --git a/docs/sdk-reference/database.md b/docs/sdk-reference/database.md index 5b87cf5f..28090f09 100644 --- a/docs/sdk-reference/database.md +++ b/docs/sdk-reference/database.md @@ -58,17 +58,18 @@ If you need to save objects manually, use a database `Session` object. 
The `Sess Example ```python -from concrete_db import Session +from concrete_db.orm import engine from concrete_db.orm.models import MetadataMixin, Base +from sqlmodel import Session class MyModel(Base, MetadataMixin): pass -with Session() as session: +with Session(engine) as session: my_model = MyModel() session.add(my_model) session.commit() ``` -Last Updated: 2024-11-08 15:12:55 UTC -Lines Changed: +72, -0 +Last Updated: 2024-11-19 02:16:01 UTC +Lines Changed: +5, -4 diff --git a/src/concrete-core/concrete/tools/knowledge.py b/src/concrete-core/concrete/tools/knowledge.py index 8cb19091..ef4cc71a 100644 --- a/src/concrete-core/concrete/tools/knowledge.py +++ b/src/concrete-core/concrete/tools/knowledge.py @@ -8,7 +8,8 @@ try: from concrete_db import crud - from concrete_db.orm import Session, models + from concrete_db.orm import engine, models + from sqlmodel import Session except ImportError: raise ImportError("Install concrete_db to use knowledge tools") @@ -57,7 +58,7 @@ def _parse_to_tree( branch=branch, ) - with Session() as db: + with Session(engine) as db: root_node_id = crud.create_repo_node(db=db, repo_node_create=root_node).id to_split.put(root_node_id) @@ -84,7 +85,7 @@ def _split_and_create_nodes(cls, parent_id: UUID, ignore_paths) -> list[UUID]: Chunks a node into smaller nodes. Adds children nodes to database, and returns them for further chunking. """ - with Session() as db: + with Session(engine) as db: parent = crud.get_repo_node(db=db, repo_node_id=parent_id) if parent is None: return [] @@ -119,7 +120,7 @@ def _split_and_create_nodes(cls, parent_id: UUID, ignore_paths) -> list[UUID]: res = [] for child in children: - with Session() as db: + with Session(engine) as db: child_node = crud.create_repo_node(db=db, repo_node_create=child) res.append(child_node.id) @@ -159,7 +160,7 @@ def _plot(cls, root_node_id: UUID): while not nodes.empty(): node_id = nodes.get() - with Session() as db: + with Session(engine) as db: node = crud.get_repo_node(db=db, repo_node_id=node_id) if node is None: continue @@ -215,7 +216,7 @@ def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0): vert_gap = height / (max([level for level in levels]) + 1) return make_pos({}) - with Session() as db: + with Session(engine) as db: root_node = crud.get_repo_node(db=db, repo_node_id=root_node_id) if not root_node: db.close() @@ -245,7 +246,7 @@ def _upsert_all_summaries_from_leaves(cls, root_node_id: UUID): Prerequisite on the graph being built. """ node_ids: list[list[UUID]] = [[root_node_id]] # Stack of node ids in ascending order of depth. root -> leaf - with Session() as db: + with Session(engine) as db: while node_ids[-1] != []: to_append = [] for node_id in node_ids[-1]: @@ -270,7 +271,7 @@ def _upsert_parent_summaries_to_root(cls, child_id: UUID) -> None: """ Recursively updates all parent summaries up until the root. Prerequisite on the graph being built. 
""" - with Session() as db: + with Session(engine) as db: child = crud.get_repo_node(db=db, repo_node_id=child_id) if child is not None: parent_id = child.parent_id @@ -287,7 +288,7 @@ def _upsert_parent_summary_from_child(cls, child_node_id: UUID): """ from concrete_core.operators import Executive - with Session() as db: + with Session(engine) as db: child = crud.get_repo_node(db=db, repo_node_id=child_node_id) if child and child.parent_id: parent = crud.get_repo_node(db=db, repo_node_id=child.parent_id) @@ -310,7 +311,7 @@ def _upsert_parent_summary_from_child(cls, child_node_id: UUID): message_format=NodeSummary, ) # type: ignore - with Session() as db: + with Session(engine) as db: parent = crud.get_repo_node(db=db, repo_node_id=parent_id) if parent is not None: parent_overall_summary = node_summary.overall_summary @@ -336,7 +337,7 @@ def _upsert_leaf_summary(cls, leaf_node_id: UUID): import chardet from concrete_core.operators import Executive - with Session() as db: + with Session(engine) as db: leaf_node = crud.get_repo_node(db=db, repo_node_id=leaf_node_id) if leaf_node is not None and leaf_node.abs_path is not None and leaf_node.partition_type == "file": path = leaf_node.abs_path @@ -360,7 +361,7 @@ def _upsert_leaf_summary(cls, leaf_node_id: UUID): options={"message_format": ChildNodeSummary}, ) repo_node_create = models.RepoNodeUpdate(summary=child_node_summary.summary) - with Session() as db: + with Session(engine) as db: crud.update_repo_node(db=db, repo_node_id=leaf_node_id, repo_node_update=repo_node_create) @classmethod @@ -370,7 +371,7 @@ def _upsert_parent_summary_from_children(cls, repo_node_id: UUID): """ from concrete_core.operators import Executive - with Session() as db: + with Session(engine) as db: parent = crud.get_repo_node(db=db, repo_node_id=repo_node_id) if parent is None: raise ValueError(f"Node {repo_node_id} not found.") @@ -395,7 +396,7 @@ def _upsert_parent_summary_from_children(cls, repo_node_id: UUID): parent_node_update = models.RepoNodeUpdate( summary=overall_summary, children_summaries=parent_children_summaries ) - with Session() as db: + with Session(engine) as db: crud.update_repo_node(db=db, repo_node_id=repo_node_id, repo_node_update=parent_node_update) @classmethod @@ -404,7 +405,7 @@ def get_node_summary(cls, node_id: UUID) -> tuple[str, str]: Returns the summary of a node. (overall_summary, children_summaries) """ - with Session() as db: + with Session(engine) as db: node = crud.get_repo_node(db=db, repo_node_id=node_id) if node is None: return ("", "") @@ -415,7 +416,7 @@ def get_node_parent(cls, node_id: UUID) -> UUID | None: """ Returns the UUID of the parent of node (if it exists, else None). """ - with Session() as db: + with Session(engine) as db: node = crud.get_repo_node(db=db, repo_node_id=node_id) if node is None or node.parent_id is None: return None @@ -427,7 +428,7 @@ def get_node_children(cls, node_id: UUID) -> dict[str, UUID]: Returns the UUIDs of the children of a node. {child_name: child_id} """ - with Session() as db: + with Session(engine) as db: node = crud.get_repo_node(db=db, repo_node_id=node_id) if node is None: return {} @@ -439,7 +440,7 @@ def get_node_path(cls, node_id: UUID) -> str: """ Returns the abs_path attribute of a node. """ - with Session() as db: + with Session(engine) as db: node = crud.get_repo_node(db=db, repo_node_id=node_id) if node is None: return "" @@ -451,7 +452,7 @@ def _get_node_by_path(cls, org: str, repo: str, branch: str, path: str | None = Returns the UUID of a node by its path. 
Enables file pointer lookup. If path is none, returns the root node """ - with Session() as db: + with Session(engine) as db: if path is None: node = crud.get_root_repo_node(db=db, org=org, repo=repo, branch=branch) else: diff --git a/src/concrete-db/concrete_db/__init__.py b/src/concrete-db/concrete_db/__init__.py index d29467fc..66f3ca9d 100644 --- a/src/concrete-db/concrete_db/__init__.py +++ b/src/concrete-db/concrete_db/__init__.py @@ -4,9 +4,10 @@ from concrete.abstract import AbstractOperator from concrete.clients import CLIClient from concrete.models.messages import Message +from sqlmodel import Session from .crud import create_message -from .orm import Session +from .orm import engine from .orm.models import MessageCreate @@ -20,7 +21,7 @@ def decorator( ) -> Message: answer = _qna(self, query, response_format, instructions) if self.store_messages: - with Session() as session: + with Session(engine) as session: create_message( session, MessageCreate( diff --git a/src/concrete-db/concrete_db/orm/__init__.py b/src/concrete-db/concrete_db/orm/__init__.py index 5d128db4..38cf324d 100644 --- a/src/concrete-db/concrete_db/orm/__init__.py +++ b/src/concrete-db/concrete_db/orm/__init__.py @@ -1,3 +1,3 @@ -from .setup import Session +from .setup import engine -__all__ = ["Session"] +__all__ = ["engine"] diff --git a/src/concrete-db/concrete_db/orm/setup.py b/src/concrete-db/concrete_db/orm/setup.py index 709e9cc9..6c7e88b4 100644 --- a/src/concrete-db/concrete_db/orm/setup.py +++ b/src/concrete-db/concrete_db/orm/setup.py @@ -1,10 +1,8 @@ import os -from contextlib import contextmanager from concrete.clients import CLIClient from dotenv import load_dotenv from sqlalchemy import URL -from sqlmodel import Session as SQLModelSession from sqlmodel import create_engine load_dotenv(override=True) @@ -36,12 +34,3 @@ connect_args = {} engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args=connect_args) - - -@contextmanager -def Session(): - session = SQLModelSession(engine) - try: - yield session - finally: - session.close() diff --git a/webapp/api/__init__.py b/webapp/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/webapp/api/server.py b/webapp/api/server.py index b15e02c1..e022ca31 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -7,7 +7,6 @@ from concrete.projects import DAGNode, Project from concrete.webutils import AuthMiddleware from concrete_db import crud -from concrete_db.orm import Session from concrete_db.orm.models import ( Client, ClientCreate, @@ -31,6 +30,8 @@ from concrete import operators +from ..common import DbDep + dotenv.load_dotenv(override=True) UNAUTHENTICATED_PATHS = {"/ping", "/docs", "/redoc", "/openapi.json", "/favicon.ico"} @@ -93,46 +94,41 @@ def ping(): @app.post("/orchestrators/", response_model=Orchestrator) -def create_orchestrator(orchestrator: OrchestratorCreate) -> Orchestrator: - with Session() as db: - return crud.create_orchestrator(db, orchestrator) +def create_orchestrator(orchestrator: OrchestratorCreate, db: DbDep) -> Orchestrator: + return crud.create_orchestrator(db, orchestrator) @app.get("/orchestrators/") -def get_orchestrators(common_read_params: CommonReadDep) -> Sequence[Orchestrator]: - with Session() as db: - return crud.get_orchestrators( - db, - skip=common_read_params.skip, - limit=common_read_params.limit, - ) +def get_orchestrators(common_read_params: CommonReadDep, db: DbDep) -> Sequence[Orchestrator]: + return crud.get_orchestrators( + db, + skip=common_read_params.skip, + 
limit=common_read_params.limit, + ) @app.get("/orchestrators/{orchestrator_id}") -def get_orchestrator(orchestrator_id: UUID) -> Orchestrator: - with Session() as db: - orchestrator = crud.get_orchestrator(db, orchestrator_id) - if orchestrator is None: - raise orchestrator_not_found(orchestrator_id) - return orchestrator +def get_orchestrator(orchestrator_id: UUID, db: DbDep) -> Orchestrator: + orchestrator = crud.get_orchestrator(db, orchestrator_id) + if orchestrator is None: + raise orchestrator_not_found(orchestrator_id) + return orchestrator @app.put("/orchestrators/{orchestrator_id}") -def update_orchestrator(orchestrator_id: UUID, orchestrator: OrchestratorUpdate) -> Orchestrator: - with Session() as db: - db_orc = crud.update_orchestrator(db, orchestrator_id, orchestrator) - if db_orc is None: - raise orchestrator_not_found(orchestrator_id) - return db_orc +def update_orchestrator(orchestrator_id: UUID, orchestrator: OrchestratorUpdate, db: DbDep) -> Orchestrator: + db_orc = crud.update_orchestrator(db, orchestrator_id, orchestrator) + if db_orc is None: + raise orchestrator_not_found(orchestrator_id) + return db_orc @app.delete("/orchestrators/{orchestrator_id}") -def delete_orchestrator(orchestrator_id: UUID) -> Orchestrator: - with Session() as db: - orchestrator = crud.delete_orchestrator(db, orchestrator_id) - if orchestrator is None: - raise orchestrator_not_found(orchestrator_id) - return orchestrator +def delete_orchestrator(orchestrator_id: UUID, db: DbDep) -> Orchestrator: + orchestrator = crud.delete_orchestrator(db, orchestrator_id) + if orchestrator is None: + raise orchestrator_not_found(orchestrator_id) + return orchestrator # endregion @@ -140,63 +136,58 @@ def delete_orchestrator(orchestrator_id: UUID) -> Orchestrator: @app.post("/operators/") -def create_operator(operator: OperatorCreate) -> Operator: - with Session() as db: - orchestrator = crud.get_orchestrator(db, operator.orchestrator_id) - if orchestrator is None: - raise orchestrator_not_found(operator.orchestrator_id) - return crud.create_operator(db, operator) +def create_operator(operator: OperatorCreate, db: DbDep) -> Operator: + orchestrator = crud.get_orchestrator(db, operator.orchestrator_id) + if orchestrator is None: + raise orchestrator_not_found(operator.orchestrator_id) + return crud.create_operator(db, operator) @app.get("/operators/") -def read_operators(common_read_params: CommonReadDep) -> Sequence[Operator]: - with Session() as db: - return crud.get_operators( - db, - skip=common_read_params.skip, - limit=common_read_params.limit, - ) +def read_operators(common_read_params: CommonReadDep, db: DbDep) -> Sequence[Operator]: + return crud.get_operators( + db, + skip=common_read_params.skip, + limit=common_read_params.limit, + ) @app.get("/orchestrators/{orchestrator_id}/operators/") def read_orchestrator_operators( orchestrator_id: UUID, common_read_params: CommonReadDep, + db: DbDep, ) -> Sequence[Operator]: - with Session() as db: - return crud.get_operators( - db, - orchestrator_id, - common_read_params.skip, - common_read_params.limit, - ) + return crud.get_operators( + db, + orchestrator_id, + common_read_params.skip, + common_read_params.limit, + ) @app.get("/orchestrators/{orchestrator_id}/operators/{operator_id}") -def read_operator(orchestrator_id: UUID, operator_id: UUID) -> Operator: - with Session() as db: - operator = crud.get_operator(db, operator_id, orchestrator_id) - if operator is None: - raise operator_not_found(operator_id) - return operator +def 
read_operator(orchestrator_id: UUID, operator_id: UUID, db: DbDep) -> Operator: + operator = crud.get_operator(db, operator_id, orchestrator_id) + if operator is None: + raise operator_not_found(operator_id) + return operator @app.put("/orchestrators/{orchestrator_id}/operators/{operator_id}") -def update_operator(orchestrator_id: UUID, operator_id: UUID, operator: OperatorUpdate) -> Operator: - with Session() as db: - db_operator = crud.update_operator(db, operator_id, orchestrator_id, operator) - if db_operator is None: - raise operator_not_found(operator_id) - return db_operator +def update_operator(orchestrator_id: UUID, operator_id: UUID, operator: OperatorUpdate, db: DbDep) -> Operator: + db_operator = crud.update_operator(db, operator_id, orchestrator_id, operator) + if db_operator is None: + raise operator_not_found(operator_id) + return db_operator @app.delete("/orchestrators/{orchestrator_id}/operators/{operator_id}") -def delete_operator(orchestrator_id: UUID, operator_id: UUID) -> Operator: - with Session() as db: - operator = crud.delete_operator(db, operator_id, orchestrator_id) - if operator is None: - raise operator_not_found(operator_id) - return operator +def delete_operator(orchestrator_id: UUID, operator_id: UUID, db: DbDep) -> Operator: + operator = crud.delete_operator(db, operator_id, orchestrator_id) + if operator is None: + raise operator_not_found(operator_id) + return operator # endregion @@ -204,22 +195,20 @@ def delete_operator(orchestrator_id: UUID, operator_id: UUID) -> Operator: @app.post("/clients/") -def create_client(client: ClientCreate) -> Client: - with Session() as db: - operator = crud.get_operator(db, client.operator_id, client.orchestrator_id) - if operator is None: - raise operator_not_found(client.operator_id) - return crud.create_client(db, client) +def create_client(client: ClientCreate, db: DbDep) -> Client: + operator = crud.get_operator(db, client.operator_id, client.orchestrator_id) + if operator is None: + raise operator_not_found(client.operator_id) + return crud.create_client(db, client) @app.get("/clients/") -def read_clients(common_read_params: CommonReadDep) -> Sequence[Client]: - with Session() as db: - return crud.get_clients( - db, - skip=common_read_params.skip, - limit=common_read_params.limit, - ) +def read_clients(common_read_params: CommonReadDep, db: DbDep) -> Sequence[Client]: + return crud.get_clients( + db, + skip=common_read_params.skip, + limit=common_read_params.limit, + ) @app.get("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/") @@ -227,42 +216,45 @@ def read_operator_clients( orchestrator_id: UUID, operator_id: UUID, common_read_params: CommonReadDep, + db: DbDep, ) -> Sequence[Client]: - with Session() as db: - return crud.get_clients( - db, - orchestrator_id=orchestrator_id, - operator_id=operator_id, - skip=common_read_params.skip, - limit=common_read_params.limit, - ) + return crud.get_clients( + db, + orchestrator_id=orchestrator_id, + operator_id=operator_id, + skip=common_read_params.skip, + limit=common_read_params.limit, + ) @app.get("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") -def read_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> Client: - with Session() as db: - client = crud.get_client(db, client_id, operator_id, orchestrator_id) - if client is None: - raise client_not_found(client_id) - return client +def read_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID, db: DbDep) -> Client: + client = crud.get_client(db, 
client_id, operator_id, orchestrator_id) + if client is None: + raise client_not_found(client_id) + return client @app.put("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") -def update_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID, client: ClientUpdate) -> Client: - with Session() as db: - db_client = crud.update_client(db, client_id, operator_id, orchestrator_id, client) - if db_client is None: - raise client_not_found(client_id) - return db_client +def update_client( + orchestrator_id: UUID, + operator_id: UUID, + client_id: UUID, + client: ClientUpdate, + db: DbDep, +) -> Client: + db_client = crud.update_client(db, client_id, operator_id, orchestrator_id, client) + if db_client is None: + raise client_not_found(client_id) + return db_client @app.delete("/orchestrators/{orchestrator_id}/operators/{operator_id}/clients/{client_id}") -def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> Client: - with Session() as db: - client = crud.delete_client(db, client_id, operator_id, orchestrator_id) - if client is None: - raise client_not_found(client_id) - return client +def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID, db: DbDep) -> Client: + client = crud.delete_client(db, client_id, operator_id, orchestrator_id) + if client is None: + raise client_not_found(client_id) + return client # endregion @@ -271,7 +263,7 @@ def delete_client(orchestrator_id: UUID, operator_id: UUID, client_id: UUID) -> @app.post("/projects/dag/") -def initialize_project(project: DagProjectCreate) -> DagProject: +def initialize_project(project: DagProjectCreate, db: DbDep) -> DagProject: """ Initiate a directed-acyclic-graph (DAG) project locally. Projects must be unique in name. @@ -285,45 +277,41 @@ def initialize_project(project: DagProjectCreate) -> DagProject: name: The name of the project to be initialized. 
""" name = project.name - with Session() as db: - db_project = crud.get_dag_project_by_name(db, name) - if db_project is not None: - raise HTTPException(status_code=400, detail=f"{name} already exists as a Project!") - db_project = crud.create_dag_project(db, project) + db_project = crud.get_dag_project_by_name(db, name) + if db_project is not None: + raise HTTPException(status_code=400, detail=f"{name} already exists as a Project!") + db_project = crud.create_dag_project(db, project) - return db_project + return db_project @app.get("/projects/dag/") -def read_projects(common_read_params: CommonReadDep) -> Sequence[DagProject]: - with Session() as db: - return crud.get_dag_projects( - db, - skip=common_read_params.skip, - limit=common_read_params.limit, - ) +def read_projects(common_read_params: CommonReadDep, db: DbDep) -> Sequence[DagProject]: + return crud.get_dag_projects( + db, + skip=common_read_params.skip, + limit=common_read_params.limit, + ) @app.get("/projects/dag/{project_name}") -def read_project(project_name: str) -> DagProject: - with Session() as db: - project = crud.get_dag_project_by_name(db, project_name) - if project is None: - raise project_not_found(project_name) - return project +def read_project(project_name: str, db: DbDep) -> DagProject: + project = crud.get_dag_project_by_name(db, project_name) + if project is None: + raise project_not_found(project_name) + return project @app.delete("/projects/dag/{project_name}") -def delete_project(project_name: str) -> DagProject: - with Session() as db: - project = crud.delete_dag_project_by_name(db, project_name) - if project is None: - raise project_not_found(project_name) - return project +def delete_project(project_name: str, db: DbDep) -> DagProject: + project = crud.delete_dag_project_by_name(db, project_name) + if project is None: + raise project_not_found(project_name) + return project @app.post("/projects/dag/{project_name}/tasks") -def expand_project_with_task(project_name: str, task: DagNodeCreate) -> DagProject: +def expand_project_with_task(project_name: str, task: DagNodeCreate, db: DbDep) -> DagProject: """ Expand a project by adding an operator task as a node in its DAG. @@ -340,31 +328,28 @@ def expand_project_with_task(project_name: str, task: DagNodeCreate) -> DagProje detail=f"Path project name {project_name} and body project name {task.project_name} don't match!", ) - with Session() as db: - project = crud.get_dag_project_by_name(db, task.project_name) - if project is None: - raise project_not_found(task.project_name) - - node = crud.get_dag_node_by_name(db, project.id, task.name) - if node is not None: - raise HTTPException( - status_code=400, detail=f"{task.name} already exists as a node for {task.project_name}!" 
- ) - - crud.create_dag_node( - db, - DagNodeBase( - project_id=project.id, - **task.model_dump(exclude=set("project")), - ), - ) + project = crud.get_dag_project_by_name(db, task.project_name) + if project is None: + raise project_not_found(task.project_name) + + node = crud.get_dag_node_by_name(db, project.id, task.name) + if node is not None: + raise HTTPException(status_code=400, detail=f"{task.name} already exists as a node for {task.project_name}!") - db.refresh(project) - return project + crud.create_dag_node( + db, + DagNodeBase( + project_id=project.id, + **task.model_dump(exclude=set("project")), + ), + ) + + db.refresh(project) + return project @app.post("/projects/dag/{project_name}/edges") -def expand_project_with_connection(project_name: str, edge: DagNodeToDagNodeLink) -> DagProject: +def expand_project_with_connection(project_name: str, edge: DagNodeToDagNodeLink, db: DbDep) -> DagProject: """ Expand a project by connecting two tasks together. The output from the parent task will be fed into the child task. @@ -380,38 +365,36 @@ def expand_project_with_connection(project_name: str, edge: DagNodeToDagNodeLink detail=f"Path project name {project_name} and body project name {edge.project_name} don't match!", ) - with Session() as db: - project = crud.get_dag_project_by_name(db, edge.project_name) - if project is None: - raise project_not_found(edge.project_name) + project = crud.get_dag_project_by_name(db, edge.project_name) + if project is None: + raise project_not_found(edge.project_name) - db_edge = crud.get_dag_edge(db, edge.project_name, edge.parent_name, edge.child_name) - if db_edge is not None: - raise HTTPException( - status_code=400, - detail=f"{edge.project_name} already has an edge from {edge.parent_name} to {edge.child_name}!", - ) + db_edge = crud.get_dag_edge(db, edge.project_name, edge.parent_name, edge.child_name) + if db_edge is not None: + raise HTTPException( + status_code=400, + detail=f"{edge.project_name} already has an edge from {edge.parent_name} to {edge.child_name}!", + ) - crud.create_dag_edge(db, edge) + crud.create_dag_edge(db, edge) - db.refresh(project) - return project + db.refresh(project) + return project @app.post("/projects/dag/{project_name}/run") -async def run_project(project_name: str) -> list[tuple[str, str]]: +async def run_project(project_name: str, db: DbDep) -> list[tuple[str, str]]: """ Run a project from its sources to its sinks. project: The name of the project to be run. 
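Taken together with the task and edge endpoints above, a hedged end-to-end sketch of building and running a small DAG over HTTP; it assumes tasks named "plan" and "summarize" were already added via the tasks endpoint, ignores the auth middleware, and uses a placeholder host and project name.

```python
import requests  # illustrative client; auth middleware is ignored here

base = "http://localhost:8000"  # placeholder host

# 1. Create the project (the endpoint returns 400 if the name is already taken).
requests.post(f"{base}/projects/dag/", json={"name": "demo"}).raise_for_status()

# 2. Connect two previously added tasks: "plan" feeds its result into "summarize".
edge = {"project_name": "demo", "parent_name": "plan", "child_name": "summarize"}
requests.post(f"{base}/projects/dag/demo/edges", json=edge).raise_for_status()

# 3. Run from sources to sinks; the response is a list of (node name, result) pairs.
for node_name, result in requests.post(f"{base}/projects/dag/demo/run").json():
    print(node_name, result[:80])
```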
""" # TODO: error handling for cycles - with Session() as session: - db_project = crud.get_dag_project_by_name(session, project_name) - if db_project is None: - raise project_not_found(project_name) - nodes = db_project.nodes - edges = db_project.edges + db_project = crud.get_dag_project_by_name(db, project_name) + if db_project is None: + raise project_not_found(project_name) + nodes = db_project.nodes + edges = db_project.edges project = Project() for node in nodes: diff --git a/webapp/auth/server.py b/webapp/auth/server.py index 30c59892..47d80587 100644 --- a/webapp/auth/server.py +++ b/webapp/auth/server.py @@ -30,12 +30,13 @@ get_authstate, get_user, ) +from concrete_db.orm import engine from concrete_db.orm.models import AuthStateCreate, AuthTokenCreate, UserCreate -from concrete_db.orm.setup import Session from fastapi import FastAPI, HTTPException, Request, status from fastapi.responses import JSONResponse, RedirectResponse from google_auth_oauthlib.flow import Flow from oauthlib.oauth2.rfc6749.errors import InvalidGrantError +from sqlmodel import Session from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -120,7 +121,7 @@ def login(request: Request, destination_url: str | None = None): ) auth_state = AuthStateCreate(state=state, destination_url=clean_destination_url) - with Session() as session: + with Session(engine) as session: create_authstate(session, auth_state) return RedirectResponse(authorization_url) @@ -149,7 +150,7 @@ def auth_callback(request: Request): ) if (state := query_params.get("state")) is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - with Session() as session: + with Session(engine) as session: auth_state = get_authstate(session, state) if auth_state is None: raise HTTPException( @@ -189,7 +190,7 @@ def auth_callback(request: Request): # user_info_service = build('oauth2', 'v2', credentials=flow.credentials) # user_info = user_info_service.userinfo().get().execute() - with Session() as session: + with Session(engine) as session: user = get_user(session, user_info["email"]) if user is None: @@ -200,12 +201,12 @@ def auth_callback(request: Request): email=user_info["email"], profile_picture_url=user_info["picture"], ) - with Session() as session: + with Session(engine) as session: user = create_user(session, new_user) # Start saving refresh tokens for later. Only given to us for the first auth. 
auth_token = AuthTokenCreate(refresh_token=flow.credentials.refresh_token, user_id=user.id) - with Session() as session: + with Session(engine) as session: create_authtoken(session, auth_token) # Not strictly necessary as of now diff --git a/webapp/common.py b/webapp/common.py index fb79f4cb..5c47e78b 100644 --- a/webapp/common.py +++ b/webapp/common.py @@ -1,7 +1,9 @@ from typing import Annotated, Any from uuid import UUID +from concrete_db.orm import engine from fastapi import Depends, Request, WebSocket +from sqlmodel import Session class ConnectionManager: @@ -63,6 +65,12 @@ async def get_user_email_from_request(request: Request) -> str: return request.session["user"]["email"] +def get_session(): + with Session(engine) as session: + yield session + + +DbDep = Annotated[Session, Depends(get_session)] UserIdDep = Annotated[UUID, Depends(get_user_id_from_request)] UserIdDepWS = Annotated[UUID, Depends(get_user_id_from_ws)] UserEmailDep = Annotated[str, Depends(get_user_email_from_request)] diff --git a/webapp/daemons/server.py b/webapp/daemons/server.py index 81fd862e..6ed53318 100644 --- a/webapp/daemons/server.py +++ b/webapp/daemons/server.py @@ -15,12 +15,13 @@ from concrete.tools.http import RestApiTool from concrete.tools.knowledge import KnowledgeGraphTool from concrete_db import crud -from concrete_db.orm import Session +from concrete_db.orm import engine from dotenv import load_dotenv from fastapi import BackgroundTasks, FastAPI, HTTPException, Request from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from sqlmodel import Session app = FastAPI() templates = Jinja2Templates(directory="templates") @@ -359,7 +360,7 @@ def recommend_documentation(self, branch: str, path: str) -> tuple[str, str]: if not found: return ("", "") - with Session() as db: + with Session(engine) as db: documentation_node = crud.get_repo_node(db=db, repo_node_id=documentation_node_id) if documentation_node is None: CLIClient.emit(f"Documentation node not found for {path}") diff --git a/webapp/main/server.py b/webapp/main/server.py index 482c6c12..1ef1f9f3 100644 --- a/webapp/main/server.py +++ b/webapp/main/server.py @@ -10,7 +10,7 @@ from concrete.orchestrators import SoftwareOrchestrator from concrete.webutils import AuthMiddleware from concrete_db import crud -from concrete_db.orm import Session +from concrete_db.orm import engine from concrete_db.orm.models import ( Base, MessageCreate, @@ -32,6 +32,7 @@ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates +from sqlmodel import Session from starlette.middleware.sessions import SessionMiddleware from webapp.common import ( @@ -82,7 +83,7 @@ def sidebar_create( def sidebar_create_orchestrator(request: Request, user_email: str): - with Session() as session: + with Session(engine) as session: user_tools = crud.get_user_tools(session, user_email) tool_names = [tool.name for tool in user_tools] @@ -97,7 +98,7 @@ def sidebar_create_orchestrator(request: Request, user_email: str): def sidebar_create_operator(orchestrator_id: UUID, request: Request, user_id: UUID): - with Session() as session: + with Session(engine) as session: orchestrator_tools = crud.get_orchestrator_tools(session, orchestrator_id, user_id) tool_names = [tool.name for tool in orchestrator_tools] return sidebar_create( @@ -111,7 +112,7 @@ def sidebar_create_operator(orchestrator_id: UUID, request: 
Request, user_id: UU def sidebar_create_project(orchestrator_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: operators = crud.get_operators(session, orchestrator_id) CLIClient.emit_sequence(operators) CLIClient.emit("\n") @@ -194,7 +195,7 @@ async def login(request: Request): @app.get("/", response_class=HTMLResponse) async def root(request: Request, user_id: UserIdDep): # TODO interactive tool creation. - with Session() as session: + with Session(engine) as session: tools_to_add = ["HTTPTool", "Arithmetic"] for tool_name in tools_to_add: tool_create = ToolCreate(name=tool_name) @@ -223,7 +224,7 @@ async def get_changelog(request: Request): @app.get("/orchestrators", response_class=HTMLResponse) async def get_orchestrator_list(request: Request, user_id: UserIdDep): - with Session() as session: + with Session(engine) as session: orchestrators = crud.get_orchestrators(session, user_id) CLIClient.emit_sequence(orchestrators) CLIClient.emit("\n") @@ -252,7 +253,7 @@ async def create_orchestrator( ) # Create orchestrator with tools assigned to it - with Session() as session: + with Session(engine) as session: orchestrator = crud.create_orchestrator(session, orchestrator_create) CLIClient.emit(f"Creating {orchestrator=}\n") CLIClient.emit(f"Assigning tools: {tool_names=}\n") @@ -272,7 +273,7 @@ async def validate_orchestrator_name( name: annotatedFormStr = "", ): def db_getter(): - with Session() as session: + with Session(engine) as session: return crud.get_orchestrator_by_name(session, name, user_id) return create_name_validation(name, db_getter, request) @@ -285,7 +286,7 @@ async def create_orchestrator_form(request: Request, user_email: UserEmailDep): @app.get("/orchestrators/{orchestrator_id}", response_class=HTMLResponse) async def get_orchestrator(orchestrator_id: UUID, request: Request, user_id: UserIdDep): - with Session() as session: + with Session(engine) as session: orchestrator = crud.get_orchestrator(session, orchestrator_id, user_id) return templates.TemplateResponse( name="orchestrator.html", @@ -298,7 +299,7 @@ async def get_orchestrator(orchestrator_id: UUID, request: Request, user_id: Use @app.delete("/orchestrators/{orchestrator_id}") async def delete_orchestrator(orchestrator_id: UUID, user_id: UserIdDep): - with Session() as session: + with Session(engine) as session: orchestrator = crud.delete_orchestrator(session, orchestrator_id, user_id) CLIClient.emit(f"{orchestrator}\n") headers = {"HX-Trigger": "getOrchestrators"} @@ -308,7 +309,7 @@ async def delete_orchestrator(orchestrator_id: UUID, user_id: UserIdDep): # === Operators === # @app.get("/orchestrators/{orchestrator_id}/operators", response_class=HTMLResponse) async def get_operator_list(orchestrator_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: operators = crud.get_operators(session, orchestrator_id) CLIClient.emit_sequence(operators) CLIClient.emit("\n") @@ -338,7 +339,7 @@ async def create_operator( title=title, orchestrator_id=orchestrator_id, ) - with Session() as session: + with Session(engine) as session: operator = crud.create_operator(session, operator_create) for tool_name in tool_names: tool = crud.get_tool_by_name(session, user_id, tool_name) @@ -359,7 +360,7 @@ async def validate_operator_name( name: annotatedFormStr = "", ): def db_getter(): - with Session() as session: + with Session(engine) as session: return crud.get_operator_by_name(session, name, orchestrator_id) return create_name_validation(name, 
db_getter, request) @@ -375,7 +376,7 @@ async def create_operator_form(orchestrator_id: UUID, request: Request, user_id: response_class=HTMLResponse, ) async def get_operator(orchestrator_id: UUID, operator_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: operator = crud.get_operator(session, operator_id, orchestrator_id) return templates.TemplateResponse( name="operator.html", @@ -389,7 +390,7 @@ async def get_operator(orchestrator_id: UUID, operator_id: UUID, request: Reques @app.delete("/orchestrators/{orchestrator_id}/operators/{operator_id}") async def delete_operator(orchestrator_id: UUID, operator_id: UUID): # TODO: generate error feedback for user when operator is in a group project - with Session() as session: + with Session(engine) as session: operator = crud.delete_operator(session, operator_id, orchestrator_id) CLIClient.emit(f"{operator}\n") headers = {"HX-Trigger": "getOperators"} @@ -401,7 +402,7 @@ async def delete_operator(orchestrator_id: UUID, operator_id: UUID): @app.get("/orchestrators/{orchestrator_id}/projects", response_class=HTMLResponse) async def get_project_list(orchestrator_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: projects = crud.get_projects(session, orchestrator_id) CLIClient.emit_sequence(projects) CLIClient.emit("\n") @@ -429,7 +430,7 @@ async def create_project( developer_id=developer_id, orchestrator_id=orchestrator_id, ) - with Session() as session: + with Session(engine) as session: project = crud.create_project(session, project_create) CLIClient.emit(f"{project}\n") return sidebar_create_project(orchestrator_id, request) @@ -442,7 +443,7 @@ async def validate_project_name( name: annotatedFormStr = "", ): def db_getter(): - with Session() as session: + with Session(engine) as session: return crud.get_project_by_name(session, name, orchestrator_id) return create_name_validation(name, db_getter, request) @@ -458,7 +459,7 @@ async def create_project_form(orchestrator_id: UUID, request: Request): response_class=HTMLResponse, ) async def get_project(orchestrator_id: UUID, project_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: project = crud.get_project(session, project_id, orchestrator_id) CLIClient.emit(f"{project}\n") return templates.TemplateResponse( @@ -472,7 +473,7 @@ async def get_project(orchestrator_id: UUID, project_id: UUID, request: Request) @app.delete("/orchestrators/{orchestrator_id}/projects/{project_id}") async def delete_project(orchestrator_id: UUID, project_id: UUID): - with Session() as session: + with Session(engine) as session: project = crud.delete_project(session, project_id, orchestrator_id) CLIClient.emit(f"{project}\n") headers = {"HX-Trigger": "getProjects"} @@ -484,7 +485,7 @@ async def delete_project(orchestrator_id: UUID, project_id: UUID): response_class=HTMLResponse, ) async def get_project_chat(orchestrator_id: UUID, project_id: UUID, request: Request): - with Session() as session: + with Session(engine) as session: chat = crud.get_messages(session, project_id) CLIClient.emit_sequence(chat) CLIClient.emit("\n") @@ -510,7 +511,7 @@ async def get_project_chat(orchestrator_id: UUID, project_id: UUID, request: Req def get_project_is_done(project_id: UUID) -> bool: - with Session() as session: + with Session(engine) as session: final_message = crud.get_completed_project(session, project_id) return final_message is not None @@ -523,7 +524,7 @@ async def 
get_downloadable_completed_project(orchestrator_id, project_id: UUID) if not get_project_is_done(project_id): raise HTTPException(status_code=404, detail=f"Project {project_id} not completed yet!") - with Session() as session: + with Session(engine) as session: final_message = crud.get_completed_project(session, project_id) if final_message is not None: pydantic_message = final_message.to_obj() @@ -552,7 +553,7 @@ async def project_chat_ws(websocket: WebSocket, orchestrator_id: UUID, project_i prompt = data["prompt"] # TODO: Use concrete.messages.TextMessage and # more tightly-couple Pydantic models with SQLModel models - with Session() as session: + with Session(engine) as session: prompt_message = crud.create_message( session, MessageCreate( diff --git a/webapp/tests/__init__.py b/webapp/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/webapp/tests/conftest.py b/webapp/tests/conftest.py new file mode 100644 index 00000000..3cac93c2 --- /dev/null +++ b/webapp/tests/conftest.py @@ -0,0 +1,36 @@ +import pytest +from concrete_db.orm.models import SQLModel +from fastapi.testclient import TestClient +from sqlmodel import Session, create_engine +from sqlmodel.pool import StaticPool + +from ..common import get_session + + +@pytest.fixture(name="session") +def session_fixture(): + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + + +@pytest.fixture(name="client") +def client_fixture( + session: Session, + request, +): + def get_session_override(): + return session + + app = request.param + + app.dependency_overrides[get_session] = get_session_override + + client = TestClient(app) + yield client + app.dependency_overrides.clear() diff --git a/webapp/tests/test_api.py b/webapp/tests/test_api.py new file mode 100644 index 00000000..1c53c83e --- /dev/null +++ b/webapp/tests/test_api.py @@ -0,0 +1,24 @@ +import pytest +from fastapi.testclient import TestClient + +from ..api.server import app + + +@pytest.mark.parametrize("client", [app], indirect=True) +def test_initialize_project(client: TestClient): + response = client.post( + "/projects/dag/", + json={ + "name": "test_project", + }, + ) + + assert response.status_code == 401 + # TODO: remove auth or add to request body above + # data = response.json() + + # assert response.status_code == 200 + # assert data["name"] == "test_project" + # assert "id" in data + # assert data["created_at"] is not None + # assert data["modified_at"] is not None From 094c2675ec6011d57df5b6fe41f99c189fef6c17 Mon Sep 17 00:00:00 2001 From: Dance Date: Mon, 18 Nov 2024 23:22:34 -0700 Subject: [PATCH 07/14] bootstrap draw_mermaid --- .../concrete/projects/dag_project.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index bb7a3b61..98276c90 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -65,6 +65,152 @@ async def execute(self) -> AsyncGenerator[tuple[str, str], None]: if node_dep_count[child] == 0: no_dep_nodes.add(child) + def draw_mermaid( + self, + first_node: str | None = None, + last_node: str | None = None, + with_styles: bool = True, + curve_style: "CurveStyle" = None, + node_styles: "NodeStyles" = None, + wrap_label_n_words: int = 9, + ) -> str: + """Draws a Mermaid graph using the 
provided graph data. + Adapted from langchain_core.runnables.graph_mermaid.draw_mermaid + + Args: + first_node (str, optional): Id of the first node. Defaults to None. + last_node (str, optional): Id of the last node. Defaults to None. + with_styles (bool, optional): Whether to include styles in the graph. + Defaults to True. + curve_style (CurveStyle, optional): Curve style for the edges. + Defaults to CurveStyle.LINEAR. + node_styles (NodeStyles, optional): Node colors for different types. + Defaults to NodeStyles(). + wrap_label_n_words (int, optional): Words to wrap the edge labels. + Defaults to 9. + + Returns: + str: Mermaid graph syntax. + """ + pass + # Initialize Mermaid graph configuration + mermaid_graph = ( + ( + f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" + f"}}}}}}%%\ngraph TD;\n" + ) + if with_styles + else "graph TD;\n" + ) + + if with_styles: + # Node formatting templates + default_class_label = "default" + format_dict = {default_class_label: "{0}({1})"} + if first_node is not None: + format_dict[first_node] = "{0}([{1}]):::first" + if last_node is not None: + format_dict[last_node] = "{0}([{1}]):::last" + + # Add nodes to the graph + for key, node in nodes.items(): + node_name = node.name.split(":")[-1] + label = ( + f"
<p>{node_name}</p>"
+                    if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
+                    and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
+                    else node_name
+                )
+                if node.metadata:
+                    label = (
+                        f"{label}<hr/><small><em>
" + + "\n".join( + f"{key} = {value}" for key, value in node.metadata.items() + ) + + "" + ) + node_label = format_dict.get(key, format_dict[default_class_label]).format( + _escape_node_label(key), label + ) + mermaid_graph += f"\t{node_label}\n" + + # Group edges by their common prefixes + edge_groups: dict[str, list[Edge]] = {} + for edge in edges: + src_parts = edge.source.split(":") + tgt_parts = edge.target.split(":") + common_prefix = ":".join( + src for src, tgt in zip(src_parts, tgt_parts) if src == tgt + ) + edge_groups.setdefault(common_prefix, []).append(edge) + + seen_subgraphs = set() + + def add_subgraph(edges: list[Edge], prefix: str) -> None: + nonlocal mermaid_graph + self_loop = len(edges) == 1 and edges[0].source == edges[0].target + if prefix and not self_loop: + subgraph = prefix.split(":")[-1] + if subgraph in seen_subgraphs: + msg = ( + f"Found duplicate subgraph '{subgraph}' -- this likely means that " + "you're reusing a subgraph node with the same name. " + "Please adjust your graph to have subgraph nodes with unique names." + ) + raise ValueError(msg) + + seen_subgraphs.add(subgraph) + mermaid_graph += f"\tsubgraph {subgraph}\n" + + for edge in edges: + source, target = edge.source, edge.target + + # Add BR every wrap_label_n_words words + if edge.data is not None: + edge_data = edge.data + words = str(edge_data).split() # Split the string into words + # Group words into chunks of wrap_label_n_words size + if len(words) > wrap_label_n_words: + edge_data = " 
 ".join( + " ".join(words[i : i + wrap_label_n_words]) + for i in range(0, len(words), wrap_label_n_words) + ) + if edge.conditional: + edge_label = f" -.  {edge_data}  .-> " + else: + edge_label = f" --  {edge_data}  --> " + else: + edge_label = " -.-> " if edge.conditional else " --> " + + mermaid_graph += ( + f"\t{_escape_node_label(source)}{edge_label}" + f"{_escape_node_label(target)};\n" + ) + + # Recursively add nested subgraphs + for nested_prefix in edge_groups: + if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix: + continue + add_subgraph(edge_groups[nested_prefix], nested_prefix) + + if prefix and not self_loop: + mermaid_graph += "\tend\n" + + # Start with the top-level edges (no common prefix) + add_subgraph(edge_groups.get("", []), "") + + # Add remaining subgraphs + for prefix in edge_groups: + if ":" in prefix or prefix == "": + continue + add_subgraph(edge_groups[prefix], prefix) + + # Add custom styles for nodes + if with_styles: + mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles()) + return mermaid_graph + + @property def is_dag(self): # AI generated From bb23838f08e53e84e49551ffa97a17a72a686091 Mon Sep 17 00:00:00 2001 From: Dance Date: Tue, 19 Nov 2024 14:49:05 -0700 Subject: [PATCH 08/14] address michael comments --- pyproject.toml | 23 +++++++++---------- .../concrete/projects/dag_project.py | 8 +------ webapp/api/server.py | 8 +++---- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 04fc90fc..b50da90b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,23 +4,22 @@ dev = [ "jupyterlab", "black", - "flake8", - "isort", - "bandit", - "pre-commit", - "mypy", - "alembic", - "pytest", - "ipykernel", - "boto3", - "boto3-stubs[ecs]" + "flake8", + "isort", + "bandit", + "pre-commit", + "mypy", + "alembic", + "pytest", + "ipykernel", + "boto3", + "boto3-stubs[ecs]", + "fastapi[standard]>=0.115.4", ] - packages = [ "concrete-core", "concrete-async", "concrete-db", - "fastapi[standard]>=0.115.4", ] [tool.uv.workspace] diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index e1f1dccb..3cb2934c 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -42,13 +42,7 @@ def add_edge( return (parent, child, res_name) - def add_node(self, name: str, node: "DAGNode") -> "DAGNode": - if name != node.name: - node.name = name - if node.name == "" or node.name in self.nodes: - # TODO: implement random name generator bandit is happy with - # https://www.geeksforgeeks.org/python-generate-random-string-of-given-length/ does not fly - node.name = max(self.nodes, default="") + "1" + def add_node(self, node: "DAGNode") -> "DAGNode": self.nodes[node.name] = node return node diff --git a/webapp/api/server.py b/webapp/api/server.py index e022ca31..ab339aff 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -4,6 +4,7 @@ from uuid import UUID import dotenv +from concrete.clients import CLIClient from concrete.projects import DAGNode, Project from concrete.webutils import AuthMiddleware from concrete_db import crud @@ -399,14 +400,13 @@ async def run_project(project_name: str, db: DbDep) -> list[tuple[str, str]]: project = Project() for node in nodes: project.add_node( - node.name, DAGNode( node.name, node.task_name, getattr(operators, node.operator_name)(), node.default_task_kwargs, node.options, - ), + ) ) for edge in edges: project.add_edge( @@ -417,8 
+417,8 @@ async def run_project(project_name: str, db: DbDep) -> list[tuple[str, str]]: result = [] async for operator, response in project.execute(): - print(operator) - print(response.text) + CLIClient.emit(operator) + CLIClient.emit(response.text) result.append((operator, response.text)) return result From d107e380dd9a23390a0f6ea0d8f9c45b0985aba4 Mon Sep 17 00:00:00 2001 From: Dance Date: Wed, 20 Nov 2024 16:15:48 -0700 Subject: [PATCH 09/14] fixed DB models, added hard coded transformation on project api --- .../fc83e32e33f5_create_dag_tables.py | 75 +++++++++++++++++++ .../concrete/projects/dag_project.py | 1 + src/concrete-db/concrete_db/crud.py | 4 +- src/concrete-db/concrete_db/orm/models.py | 32 ++++++-- webapp/api/server.py | 8 +- 5 files changed, 106 insertions(+), 14 deletions(-) create mode 100644 migrations/versions/fc83e32e33f5_create_dag_tables.py diff --git a/migrations/versions/fc83e32e33f5_create_dag_tables.py b/migrations/versions/fc83e32e33f5_create_dag_tables.py new file mode 100644 index 00000000..a1fca8f3 --- /dev/null +++ b/migrations/versions/fc83e32e33f5_create_dag_tables.py @@ -0,0 +1,75 @@ +"""create dag tables + +Revision ID: fc83e32e33f5 +Revises: 3bb0633b746d +Create Date: 2024-11-20 13:47:29.927238 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = 'fc83e32e33f5' +down_revision: Union[str, None] = '3bb0633b746d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'dagproject', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('modified_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name'), + ) + op.create_table( + 'dagnode', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('modified_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True), + sa.Column('project_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.Column('operator_name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.Column('task_name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False), + sa.Column('default_task_kwargs', sa.JSON(), nullable=True), + sa.Column('options', sa.JSON(), nullable=True), + sa.ForeignKeyConstraint(['project_name'], ['dagproject.name'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name', 'project_name', name='no_duplicate_names_per_project'), + ) + op.create_table( + 'dagnodetodagnodelink', + sa.Column('project_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('parent_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('child_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('input_to_child', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ['project_name', 'child_name'], ['dagnode.project_name', 'dagnode.name'], 
ondelete='CASCADE' + ), + sa.ForeignKeyConstraint( + ['project_name', 'parent_name'], ['dagnode.project_name', 'dagnode.name'], ondelete='CASCADE' + ), + sa.ForeignKeyConstraint(['project_name'], ['dagproject.name'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('project_name', 'parent_name', 'child_name'), + ) + op.create_index( + op.f('ix_dagnodetodagnodelink_project_name'), 'dagnodetodagnodelink', ['project_name'], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_dagnodetodagnodelink_project_name'), table_name='dagnodetodagnodelink') + op.drop_table('dagnodetodagnodelink') + op.drop_table('dagnode') + op.drop_table('dagproject') + # ### end Alembic commands ### diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index 3cb2934c..ecceed57 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -138,6 +138,7 @@ async def execute(self, options: dict = {}) -> Any: """ kwargs = self.default_task_kwargs | self.dynamic_kwargs options = self.options | options + print(kwargs) res = self.bound_task(**kwargs, options=self.options | options) if options.get("run_async"): res = res.get().message diff --git a/src/concrete-db/concrete_db/crud.py b/src/concrete-db/concrete_db/crud.py index eff75e12..3617f352 100644 --- a/src/concrete-db/concrete_db/crud.py +++ b/src/concrete-db/concrete_db/crud.py @@ -507,8 +507,8 @@ def get_dag_project_by_name(db: Session, name: str) -> DagProject | None: return db.scalars(stmt).first() -def get_dag_node_by_name(db: Session, project_id: UUID, node_name: str) -> DagNode | None: - stmt = select(DagNode).where(DagNode.project_id == project_id).where(DagNode.name == node_name) +def get_dag_node_by_name(db: Session, project_name: str, node_name: str) -> DagNode | None: + stmt = select(DagNode).where(DagNode.project_name == project_name).where(DagNode.name == node_name) return db.scalars(stmt).first() diff --git a/src/concrete-db/concrete_db/orm/models.py b/src/concrete-db/concrete_db/orm/models.py index 57e0cd54..b6a0ee74 100644 --- a/src/concrete-db/concrete_db/orm/models.py +++ b/src/concrete-db/concrete_db/orm/models.py @@ -10,7 +10,13 @@ from concrete.tools import MetaTool from concrete.tools.utils import tool_name_to_class from pydantic import ConfigDict, ValidationError, model_validator -from sqlalchemy import CheckConstraint, Column, DateTime, UniqueConstraint +from sqlalchemy import ( + CheckConstraint, + Column, + DateTime, + ForeignKeyConstraint, + UniqueConstraint, +) from sqlalchemy.schema import Index from sqlalchemy.sql import func from sqlmodel import JSON, Field, Relationship, SQLModel @@ -67,12 +73,24 @@ class UserToolLink(Base, table=True): class DagNodeToDagNodeLink(Base, table=True): project_name: str = Field(foreign_key="dagproject.name", primary_key=True, index=True, ondelete="CASCADE") - parent_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") - child_name: str = Field(foreign_key="dagnode.name", primary_key=True, ondelete="CASCADE") + parent_name: str = Field(primary_key=True) + child_name: str = Field(primary_key=True) input_to_child: str = Field(description="Name of the argument to the child task", default="message") project: "DagProject" = Relationship(back_populates="edges") + __table_args__ = ( + ForeignKeyConstraint( + ["project_name", "parent_name"], + 
["dagnode.project_name", "dagnode.name"], + ondelete="CASCADE", + ), + ForeignKeyConstraint( + ["project_name", "child_name"], + ["dagnode.project_name", "dagnode.name"], + ondelete="CASCADE", + ), + ) # TODO maybe store transformation function @@ -364,9 +382,9 @@ class DagProject(DagProjectBase, MetadataMixin, table=True): class DagNodeBase(Base): - project_id: UUID = Field( - description="ID of DAG Project this DAG Node belongs to.", - foreign_key="dagproject.id", + project_name: str = Field( + description="Name of DAG Project this DAG Node belongs to.", + foreign_key="dagproject.name", ondelete="CASCADE", ) name: str = Field( @@ -393,7 +411,7 @@ class DagNodeBase(Base): sa_column=Column(JSON), ) - __table_args__ = (UniqueConstraint("name", "project_id", name="no_duplicate_names_per_project"),) + __table_args__ = (UniqueConstraint("name", "project_name", name="no_duplicate_names_per_project"),) # TODO: options diff --git a/webapp/api/server.py b/webapp/api/server.py index ab339aff..670b9f24 100644 --- a/webapp/api/server.py +++ b/webapp/api/server.py @@ -333,16 +333,13 @@ def expand_project_with_task(project_name: str, task: DagNodeCreate, db: DbDep) if project is None: raise project_not_found(task.project_name) - node = crud.get_dag_node_by_name(db, project.id, task.name) + node = crud.get_dag_node_by_name(db, project.name, task.name) if node is not None: raise HTTPException(status_code=400, detail=f"{task.name} already exists as a node for {task.project_name}!") crud.create_dag_node( db, - DagNodeBase( - project_id=project.id, - **task.model_dump(exclude=set("project")), - ), + DagNodeBase(**task.model_dump()), ) db.refresh(project) @@ -413,6 +410,7 @@ async def run_project(project_name: str, db: DbDep) -> list[tuple[str, str]]: edge.parent_name, edge.child_name, edge.input_to_child, + lambda x: x.text, # TODO: account for different message types ) result = [] From bb4c42d897b6b74f644f2c827539dc2b9e8ed330 Mon Sep 17 00:00:00 2001 From: Dance Date: Thu, 21 Nov 2024 02:07:22 -0700 Subject: [PATCH 10/14] first iteration on mermaid output --- src/concrete-core/concrete/mermaid.py | 8 + .../concrete/projects/dag_project.py | 180 +++++------------- src/concrete-core/concrete/utils.py | 70 ++++++- 3 files changed, 121 insertions(+), 137 deletions(-) create mode 100644 src/concrete-core/concrete/mermaid.py diff --git a/src/concrete-core/concrete/mermaid.py b/src/concrete-core/concrete/mermaid.py new file mode 100644 index 00000000..9f3ed030 --- /dev/null +++ b/src/concrete-core/concrete/mermaid.py @@ -0,0 +1,8 @@ +from enum import StrEnum + + +class FlowchartDirection(StrEnum): + LEFT_RIGHT = "LR" + RIGHT_LEFT = "RL" + TOP_DOWN = "TD" + BOTTOM_UP = "BT" diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index f512724c..4a9ed29c 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -2,8 +2,10 @@ from collections.abc import AsyncGenerator from typing import Any, Callable +from concrete.mermaid import FlowchartDirection from concrete.operators import Operator from concrete.state import StatefulMixin +from concrete.utils import bfs_traversal, find_sources_and_sinks class Project(StatefulMixin): @@ -70,149 +72,57 @@ async def execute(self) -> AsyncGenerator[tuple[str, str], None]: def draw_mermaid( self, - first_node: str | None = None, - last_node: str | None = None, - with_styles: bool = True, - curve_style: "CurveStyle" = None, - node_styles: "NodeStyles" = 
None, - wrap_label_n_words: int = 9, + title: str | None = None, + direction: FlowchartDirection = FlowchartDirection.TOP_DOWN, + start_nodes: list[str] = [], + end_nodes: list[str] = [], ) -> str: - """Draws a Mermaid graph using the provided graph data. - Adapted from langchain_core.runnables.graph_mermaid.draw_mermaid + """Draws a Mermaid flowchart from the DAG. Args: - first_node (str, optional): Id of the first node. Defaults to None. - last_node (str, optional): Id of the last node. Defaults to None. - with_styles (bool, optional): Whether to include styles in the graph. - Defaults to True. - curve_style (CurveStyle, optional): Curve style for the edges. - Defaults to CurveStyle.LINEAR. - node_styles (NodeStyles, optional): Node colors for different types. - Defaults to NodeStyles(). - wrap_label_n_words (int, optional): Words to wrap the edge labels. - Defaults to 9. + title (str, optional): Title of the flowchart. Defaults to None. + direction (FlowchartDirection, optional): + Direction of the flowchart, i.e. start and end positions. Defaults to top down. + start_nodes (list[str], optional): Names of the source (i.e. start) nodes. Defaults to project source nodes. + end_nodes (list[str], optional): Names of the sink (i.e. end) nodes. Defaults to project sink nodes. Returns: - str: Mermaid graph syntax. + str: Mermaid flowchart syntax. """ - pass - # Initialize Mermaid graph configuration - mermaid_graph = ( - ( - f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'" - f"}}}}}}%%\ngraph TD;\n" - ) - if with_styles - else "graph TD;\n" + flowchart = f"flowchart {direction}\n" + + if title is not None: + flowchart = flowchart + f"---\ntitle: {title}\n---\n" + + remove_whitespace: Callable[[str], str] = lambda string: "".join(string.split()) + get_child: Callable[[tuple[str, str, Callable]], str] = lambda edge: edge[0] + + def process_node(node: str) -> None: + nonlocal flowchart + flowchart = flowchart + f"\t{remove_whitespace(node)}([\"{self.nodes[node]!s}\"])\n" + + def process_edge(node: str, edge: tuple[str, str, Callable]) -> None: + # TODO: design a good string representation for result transformation + nonlocal flowchart + flowchart = flowchart + f"\t{remove_whitespace(node)} -->|{edge[1]}| {remove_whitespace(edge[0])}\n" + + if not start_nodes or not end_nodes: + sources, sinks = find_sources_and_sinks(self.nodes, self.edges, get_child) + if not start_nodes: + start_nodes = sources + if not end_nodes: + end_nodes = sinks + + bfs_traversal( + self.edges, + start_nodes, + end_nodes, + process_node=process_node, + process_edge=process_edge, + get_neighbor=get_child, ) - if with_styles: - # Node formatting templates - default_class_label = "default" - format_dict = {default_class_label: "{0}({1})"} - if first_node is not None: - format_dict[first_node] = "{0}([{1}]):::first" - if last_node is not None: - format_dict[last_node] = "{0}([{1}]):::last" - - # Add nodes to the graph - for key, node in nodes.items(): - node_name = node.name.split(":")[-1] - label = ( - f"
<p>{node_name}</p>"
-                    if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
-                    and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
-                    else node_name
-                )
-                if node.metadata:
-                    label = (
-                        f"{label}<hr/><small><em>
" - + "\n".join( - f"{key} = {value}" for key, value in node.metadata.items() - ) - + "" - ) - node_label = format_dict.get(key, format_dict[default_class_label]).format( - _escape_node_label(key), label - ) - mermaid_graph += f"\t{node_label}\n" - - # Group edges by their common prefixes - edge_groups: dict[str, list[Edge]] = {} - for edge in edges: - src_parts = edge.source.split(":") - tgt_parts = edge.target.split(":") - common_prefix = ":".join( - src for src, tgt in zip(src_parts, tgt_parts) if src == tgt - ) - edge_groups.setdefault(common_prefix, []).append(edge) - - seen_subgraphs = set() - - def add_subgraph(edges: list[Edge], prefix: str) -> None: - nonlocal mermaid_graph - self_loop = len(edges) == 1 and edges[0].source == edges[0].target - if prefix and not self_loop: - subgraph = prefix.split(":")[-1] - if subgraph in seen_subgraphs: - msg = ( - f"Found duplicate subgraph '{subgraph}' -- this likely means that " - "you're reusing a subgraph node with the same name. " - "Please adjust your graph to have subgraph nodes with unique names." - ) - raise ValueError(msg) - - seen_subgraphs.add(subgraph) - mermaid_graph += f"\tsubgraph {subgraph}\n" - - for edge in edges: - source, target = edge.source, edge.target - - # Add BR every wrap_label_n_words words - if edge.data is not None: - edge_data = edge.data - words = str(edge_data).split() # Split the string into words - # Group words into chunks of wrap_label_n_words size - if len(words) > wrap_label_n_words: - edge_data = " 
 ".join( - " ".join(words[i : i + wrap_label_n_words]) - for i in range(0, len(words), wrap_label_n_words) - ) - if edge.conditional: - edge_label = f" -.  {edge_data}  .-> " - else: - edge_label = f" --  {edge_data}  --> " - else: - edge_label = " -.-> " if edge.conditional else " --> " - - mermaid_graph += ( - f"\t{_escape_node_label(source)}{edge_label}" - f"{_escape_node_label(target)};\n" - ) - - # Recursively add nested subgraphs - for nested_prefix in edge_groups: - if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix: - continue - add_subgraph(edge_groups[nested_prefix], nested_prefix) - - if prefix and not self_loop: - mermaid_graph += "\tend\n" - - # Start with the top-level edges (no common prefix) - add_subgraph(edge_groups.get("", []), "") - - # Add remaining subgraphs - for prefix in edge_groups: - if ":" in prefix or prefix == "": - continue - add_subgraph(edge_groups[prefix], prefix) - - # Add custom styles for nodes - if with_styles: - mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles()) - return mermaid_graph - + return flowchart @property def is_dag(self) -> bool: diff --git a/src/concrete-core/concrete/utils.py b/src/concrete-core/concrete/utils.py index 5c109cf0..b633585b 100644 --- a/src/concrete-core/concrete/utils.py +++ b/src/concrete-core/concrete/utils.py @@ -1,14 +1,16 @@ -"""AI generated""" - import base64 import os +from collections import defaultdict, deque +from collections.abc import Callable from datetime import timedelta +from typing import Any, TypeVar import dotenv import jwt dotenv.load_dotenv(override=True) +# region AI generated # # These are slow-changing, so the certs are hardcoded directly here GOOGLE_OIDC_DISCOVERY = "https://accounts.google.com/.well-known/openid-configuration" # GOOGLE_OIDC_CONFIG = requests.get(GOOGLE_OIDC_DISCOVERY).json() @@ -64,3 +66,67 @@ def map_python_type_to_json_type(py_type) -> str: return type_map[py_type] else: raise ValueError(f"Unexpected Python type: {py_type}") + + +# endregion + + +NodeId = TypeVar("NodeId") +Edge = TypeVar("Edge") + + +def find_sources_and_sinks( + nodes: dict[NodeId, Any], + edges: dict[NodeId, list[Edge]], + get_neighbor: Callable[[Edge], NodeId] = lambda x: x, # type: ignore[assignment, return-value] +) -> tuple[list[NodeId], list[NodeId]]: + in_degree: defaultdict[NodeId, int] = defaultdict(int) + out_degree: dict[NodeId, int] = {node: len(edges.get(node, [])) for node in nodes} + + for node in nodes: + # For each node, check its neighbors (outgoing edges) + neighbors = {get_neighbor(edge) for edge in edges.get(node, [])} + for neighbor in neighbors: + in_degree[neighbor] += 1 + + source_nodes = [node for node in nodes if in_degree[node] == 0 and out_degree[node] > 0] + sink_nodes = [node for node in nodes if in_degree[node] > 0 and out_degree[node] == 0] + + return source_nodes, sink_nodes + + +def bfs_traversal( + edges: dict[NodeId, list[Edge]], + start_nodes: list[NodeId], + end_nodes: list[NodeId] = [], + process_node: Callable[[NodeId], Any] = print, + process_edge: Callable[[NodeId, Edge], Any] = print, + get_neighbor: Callable[[Edge], NodeId] = lambda x: x, # type: ignore[assignment, return-value] +): + end_nodes_set = set(end_nodes) + + def bfs(queue: deque[NodeId], visited: set[NodeId]) -> set[NodeId]: + if not queue: + return visited # If queue is empty, return visited nodes (end condition) + + node = queue.popleft() + if node in visited: + return bfs(queue, visited) # Skip if already visited + + # Process the current node + 
process_node(node) + visited = visited | {node} + + # If an end node is reached and provided, stop traversal + if node in end_nodes_set: + return visited + + for edge in edges.get(node, []): + process_edge(node, edge) + + # Enqueue unvisited neighbors + neighbors = [get_neighbor(edge) for edge in edges.get(node, [])] + unvisited_neighbors = deque(neighbor for neighbor in neighbors if neighbor not in visited) + return bfs(queue + unvisited_neighbors, visited) + + return bfs(deque(start_nodes), set()) From a0eaf734485b79ce8650bd404bc2e17cde6019c0 Mon Sep 17 00:00:00 2001 From: Dance Date: Thu, 21 Nov 2024 02:19:35 -0700 Subject: [PATCH 11/14] fixed title bug --- src/concrete-core/concrete/projects/dag_project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index 4a9ed29c..54c58c87 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -92,7 +92,7 @@ def draw_mermaid( flowchart = f"flowchart {direction}\n" if title is not None: - flowchart = flowchart + f"---\ntitle: {title}\n---\n" + flowchart = f"---\ntitle: {title}\n---\n" + flowchart remove_whitespace: Callable[[str], str] = lambda string: "".join(string.split()) get_child: Callable[[tuple[str, str, Callable]], str] = lambda edge: edge[0] From d6bc0c0aab868ed4f979a913b79ddeea17e2e33e Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 4 Dec 2024 01:58:17 -0800 Subject: [PATCH 12/14] ran fixed lint --- migrations/env.py | 3 +-- src/concrete-core/concrete/__main__.py | 3 ++- .../concrete/orchestrators/software_orchestrator.py | 3 ++- src/concrete-core/concrete/tools/aws.py | 3 +-- src/concrete-core/concrete/tools/github.py | 1 - src/concrete-core/concrete/webutils.py | 3 +-- tests/test_dependencies.py | 1 - tests/test_models.py | 5 ++--- tests/test_prompts.py | 1 - 9 files changed, 9 insertions(+), 14 deletions(-) diff --git a/migrations/env.py b/migrations/env.py index fec36318..c96f1aed 100644 --- a/migrations/env.py +++ b/migrations/env.py @@ -3,11 +3,10 @@ import dotenv from alembic import context +from concrete.clients import CLIClient from sqlalchemy import URL, engine_from_config, pool from sqlmodel import SQLModel -from concrete.clients import CLIClient - dotenv.load_dotenv(override=True) diff --git a/src/concrete-core/concrete/__main__.py b/src/concrete-core/concrete/__main__.py index 5fdac1ad..056babb7 100644 --- a/src/concrete-core/concrete/__main__.py +++ b/src/concrete-core/concrete/__main__.py @@ -1,10 +1,11 @@ import argparse import asyncio -from concrete import orchestrators from concrete.clients import CLIClient from concrete.tools.aws import AwsTool, Container +from concrete import orchestrators + try: import concrete_async # noqa diff --git a/src/concrete-core/concrete/orchestrators/software_orchestrator.py b/src/concrete-core/concrete/orchestrators/software_orchestrator.py index 07c1e584..979ee269 100644 --- a/src/concrete-core/concrete/orchestrators/software_orchestrator.py +++ b/src/concrete-core/concrete/orchestrators/software_orchestrator.py @@ -2,12 +2,13 @@ from typing import cast from uuid import UUID, uuid1, uuid4 -from concrete import prompts from concrete.clients.openai import OpenAIClient from concrete.operators import Developer, Executive, Operator from concrete.projects import SoftwareProject from concrete.state import ProjectStatus, State, StatefulMixin +from concrete import prompts + from . 
import Orchestrator diff --git a/src/concrete-core/concrete/tools/aws.py b/src/concrete-core/concrete/tools/aws.py index de94db91..9d2d16b9 100644 --- a/src/concrete-core/concrete/tools/aws.py +++ b/src/concrete-core/concrete/tools/aws.py @@ -7,11 +7,10 @@ from datetime import datetime, timezone from typing import Optional -from dotenv import dotenv_values - from concrete.clients import CLIClient from concrete.models.base import ConcreteModel from concrete.tools import MetaTool +from dotenv import dotenv_values class Container(ConcreteModel): diff --git a/src/concrete-core/concrete/tools/github.py b/src/concrete-core/concrete/tools/github.py index 171705a1..10c0bb1b 100644 --- a/src/concrete-core/concrete/tools/github.py +++ b/src/concrete-core/concrete/tools/github.py @@ -5,7 +5,6 @@ import zipfile import requests - from concrete.clients import CLIClient from concrete.tools import MetaTool from concrete.tools.http import HTTPTool, RestApiTool diff --git a/src/concrete-core/concrete/webutils.py b/src/concrete-core/concrete/webutils.py index 0ff9f15d..53eec5ea 100644 --- a/src/concrete-core/concrete/webutils.py +++ b/src/concrete-core/concrete/webutils.py @@ -1,11 +1,10 @@ from typing import cast +from concrete.utils import verify_jwt from fastapi import Request, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware -from concrete.utils import verify_jwt - class AuthMiddleware(BaseHTTPMiddleware): def __init__(self, app, exclude_paths: set[str] | None = None): diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 4428dd49..83a7590d 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -2,7 +2,6 @@ import sys import pytest - from concrete.orchestrators import SoftwareOrchestrator diff --git a/tests/test_models.py b/tests/test_models.py index 401d791e..22f52fb0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,13 +1,12 @@ import unittest from uuid import uuid4 +from concrete.models.messages import TextMessage +from concrete.operators import Developer, Executive from concrete_db.orm.models import Message as SQLModelMessage from concrete_db.orm.models import Operator as SQLModelOperator from concrete_db.orm.models import Project -from concrete.models.messages import TextMessage -from concrete.operators import Developer, Executive - class TestSQLModels(unittest.TestCase): """ diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 619ba81a..8d5df83e 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -3,7 +3,6 @@ from typing import List, Tuple import pytest - from concrete.orchestrators import Orchestrator # TODO: decide where utils go From c8aada828e3a23a8f489050a84946ec3a4cff2aa Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 4 Dec 2024 03:21:43 -0800 Subject: [PATCH 13/14] WIP adding signature to mermaid graphs --- src/concrete-core/concrete/projects/dag_project.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index 54c58c87..ac12c452 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -1,3 +1,4 @@ +import inspect from collections import defaultdict from collections.abc import AsyncGenerator from typing import Any, Callable @@ -202,4 +203,6 @@ async def execute(self, options: dict = {}) -> Any: return self.name, res def __str__(self): - return 
f"{type(self.operator).__name__}.{self.boost_str}(**{self.default_task_kwargs})" + signature = inspect.signature(self.bound_task) + + return f"{type(self.operator).__name__}.{self.boost_str}({signature})" From b9dea821ad22fc1f79ac3d1bf523a8c020ec379c Mon Sep 17 00:00:00 2001 From: abjjabjj Date: Wed, 4 Dec 2024 12:19:28 -0800 Subject: [PATCH 14/14] added parameters for all tasks --- .../concrete/projects/dag_project.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/concrete-core/concrete/projects/dag_project.py b/src/concrete-core/concrete/projects/dag_project.py index ac12c452..bc8c9731 100644 --- a/src/concrete-core/concrete/projects/dag_project.py +++ b/src/concrete-core/concrete/projects/dag_project.py @@ -1,6 +1,6 @@ -import inspect from collections import defaultdict from collections.abc import AsyncGenerator +from inspect import Parameter, signature from typing import Any, Callable from concrete.mermaid import FlowchartDirection @@ -203,6 +203,20 @@ async def execute(self, options: dict = {}) -> Any: return self.name, res def __str__(self): - signature = inspect.signature(self.bound_task) - - return f"{type(self.operator).__name__}.{self.boost_str}({signature})" + boost_signature = signature(getattr(self.operator.__class__, self.boost_str)) + params = [ + Parameter( + param.name, + param.kind, + default=( + param.default + if param.name not in self.default_task_kwargs + else self.default_task_kwargs[param.name] + ), + annotation=param.annotation, + ) + for param in boost_signature.parameters.values() + ] + boost_signature = boost_signature.replace(parameters=params) + param_str = ", ".join(str(param) for param in boost_signature.parameters.values()) + return f"{type(self.operator).__name__}.{self.boost_str}({param_str})"
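
A minimal end-to-end sketch of the DAG project API as it stands after PATCH 14/14, for trying the series locally. It is illustrative only: it assumes that Developer and Executive (instantiated with no arguments, as the web API does) expose the default "chat" boost, that the boost accepts a message keyword, and that results carry a .text attribute as the run_project endpoint assumes; executing it calls the configured LLM client, so API credentials are required.

    import asyncio

    from concrete.operators import Developer, Executive
    from concrete.projects import DAGNode, Project

    # Two-node DAG: the Executive plans, the Developer implements the plan.
    project = Project()
    project.add_node(
        DAGNode("plan", "chat", Executive(), {"message": "Plan a hello-world Flask app."})
    )
    project.add_node(DAGNode("implement", "chat", Developer()))

    # Route the Executive's result into the Developer's "message" kwarg,
    # mirroring the hard-coded lambda x: x.text transformation in run_project.
    project.add_edge("plan", "implement", "message", lambda res: res.text)

    # Render the DAG as a Mermaid flowchart; node labels come from DAGNode.__str__.
    print(project.draw_mermaid(title="hello-world"))

    async def main() -> None:
        # Nodes with no unmet dependencies run first; results stream back as (node name, result).
        async for name, result in project.execute():
            print(name, result)

    asyncio.run(main())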