Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draw mermaid flowchart Version 1 #190

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
51c8114
refactored DAGNode and Project to use a dictionary for nodes instead …
abjjabjj Nov 8, 2024
0e4caab
Merge branch 'main' into dance/project_api
abjjabjj Nov 8, 2024
3d277e3
WIP DB integration
abjjabjj Nov 13, 2024
052cec9
testless, first iteration of DAGProject API
abjjabjj Nov 13, 2024
6e91398
testless, first iteration of DAGProject API
abjjabjj Nov 13, 2024
df90fd1
Merge branch 'main' into dance/project_api
abjjabjj Nov 13, 2024
fc50f4d
used JSON columns for default kwargs and options
abjjabjj Nov 15, 2024
dea1b1e
Merge branch 'main' into dance/project_api
abjjabjj Nov 15, 2024
7a9fe90
Merge branch 'main' into dance/project_api
abjjabjj Nov 15, 2024
fbf179e
Merge branch 'main' into dance/project_api
abjjabjj Nov 15, 2024
a54cb5b
session refactor and bootstrap for webapp tests
abjjabjj Nov 19, 2024
094c267
bootstrap draw_mermaid
abjjabjj Nov 19, 2024
bb23838
address michael comments
abjjabjj Nov 19, 2024
d107e38
fixed DB models, added hard coded transformation on project api
abjjabjj Nov 20, 2024
1e829bd
Merge branch 'main' into dance/project_api
abjjabjj Nov 20, 2024
82ed3ab
Merge branch 'main' into dance/draw_mermaid_flowchart
abjjabjj Nov 21, 2024
34b2ddb
Merge branch 'dance/project_api' into dance/draw_mermaid_flowchart
abjjabjj Nov 21, 2024
bb4c42d
first iteration on mermaid output
abjjabjj Nov 21, 2024
a0eaf73
fixed title bug
abjjabjj Nov 21, 2024
877f293
Merge branch 'main' into dance/draw_mermaid_flowchart
abjjabjj Dec 4, 2024
d6bc0c0
ran fixed lint
abjjabjj Dec 4, 2024
c8aada8
WIP adding signature to mermaid graphs
abjjabjj Dec 4, 2024
b9dea82
added parameters for all tasks
abjjabjj Dec 4, 2024
f68430d
Merge branch 'main' into dance/draw_mermaid_flowchart
abjjabjj Dec 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

import dotenv
from alembic import context
from concrete.clients import CLIClient
from sqlalchemy import URL, engine_from_config, pool
from sqlmodel import SQLModel

from concrete.clients import CLIClient

dotenv.load_dotenv(override=True)


Expand Down
75 changes: 75 additions & 0 deletions migrations/versions/fc83e32e33f5_create_dag_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""create dag tables

Revision ID: fc83e32e33f5
Revises: 3bb0633b746d
Create Date: 2024-11-20 13:47:29.927238

"""

from typing import Sequence, Union

import sqlalchemy as sa
import sqlmodel
from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'fc83e32e33f5'
down_revision: Union[str, None] = '3bb0633b746d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
'dagproject',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('modified_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name'),
)
op.create_table(
'dagnode',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('modified_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
sa.Column('project_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column('operator_name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column('task_name', sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column('default_task_kwargs', sa.JSON(), nullable=True),
sa.Column('options', sa.JSON(), nullable=True),
sa.ForeignKeyConstraint(['project_name'], ['dagproject.name'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'project_name', name='no_duplicate_names_per_project'),
)
op.create_table(
'dagnodetodagnodelink',
sa.Column('project_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('parent_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('child_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('input_to_child', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(
['project_name', 'child_name'], ['dagnode.project_name', 'dagnode.name'], ondelete='CASCADE'
),
sa.ForeignKeyConstraint(
['project_name', 'parent_name'], ['dagnode.project_name', 'dagnode.name'], ondelete='CASCADE'
),
sa.ForeignKeyConstraint(['project_name'], ['dagproject.name'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('project_name', 'parent_name', 'child_name'),
)
op.create_index(
op.f('ix_dagnodetodagnodelink_project_name'), 'dagnodetodagnodelink', ['project_name'], unique=False
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_dagnodetodagnodelink_project_name'), table_name='dagnodetodagnodelink')
op.drop_table('dagnodetodagnodelink')
op.drop_table('dagnode')
op.drop_table('dagproject')
# ### end Alembic commands ###
22 changes: 11 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@
dev = [
"jupyterlab",
"black",
"flake8",
"isort",
"bandit",
"pre-commit",
"mypy",
"alembic",
"pytest",
"ipykernel",
"boto3",
"boto3-stubs[ecs]"
"flake8",
"isort",
"bandit",
"pre-commit",
"mypy",
"alembic",
"pytest",
"ipykernel",
"boto3",
"boto3-stubs[ecs]",
"fastapi[standard]>=0.115.4",
]

packages = [
"concrete-core",
"concrete-async",
Expand Down
2 changes: 1 addition & 1 deletion src/concrete-async/concrete_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import Any, cast

from celery.result import AsyncResult
from concrete.clients import CLIClient, model_to_schema
from concrete_async.tasks import abstract_operation

import concrete
from concrete.clients import CLIClient, model_to_schema
from concrete.models import KombuMixin, Message, Operation


Expand Down
4 changes: 2 additions & 2 deletions src/concrete-core/concrete/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dotenv import load_dotenv

from . import operators, orchestrators
from . import abstract, models, operators, orchestrators

# Always runs even when importing submodules
# https://stackoverflow.com/a/27144933
load_dotenv(override=True)
__all__ = ["operators", "orchestrators"]
__all__ = ["abstract", "models", "operators", "orchestrators"]
3 changes: 2 additions & 1 deletion src/concrete-core/concrete/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import argparse
import asyncio

from concrete import orchestrators
from concrete.clients import CLIClient
from concrete.tools.aws import AwsTool, Container

from concrete import orchestrators

try:
import concrete_async # noqa

Expand Down
8 changes: 8 additions & 0 deletions src/concrete-core/concrete/mermaid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from enum import StrEnum


class FlowchartDirection(StrEnum):
LEFT_RIGHT = "LR"
RIGHT_LEFT = "RL"
TOP_DOWN = "TD"
BOTTOM_UP = "BT"
3 changes: 3 additions & 0 deletions src/concrete-core/concrete/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import base

__all__ = ["base"]
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from typing import cast
from uuid import UUID, uuid1, uuid4

from concrete import prompts
from concrete.clients.openai import OpenAIClient
from concrete.operators import Developer, Executive, Operator
from concrete.projects import SoftwareProject
from concrete.state import ProjectStatus, State, StatefulMixin

from concrete import prompts

from . import Orchestrator


Expand Down
2 changes: 0 additions & 2 deletions src/concrete-core/concrete/projects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from .dag_project import DAGNode, Project
from .software_project import SoftwareProject

PROJECTS: dict[str, Project] = {}

__all__ = ["DAGNode", "Project", "SoftwareProject"]
115 changes: 97 additions & 18 deletions src/concrete-core/concrete/projects/dag_project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from collections import defaultdict
from collections.abc import AsyncGenerator
from inspect import Parameter, signature
from typing import Any, Callable

from concrete.mermaid import FlowchartDirection
from concrete.operators import Operator
from concrete.state import StatefulMixin
from concrete.utils import bfs_traversal, find_sources_and_sinks


class Project(StatefulMixin):
Expand All @@ -16,18 +19,18 @@ def __init__(
self,
options: dict = {},
) -> None:
self.edges: dict[DAGNode, list[tuple[DAGNode, str, Callable]]] = defaultdict(list)
self.edges: dict[str, list[tuple[str, str, Callable]]] = defaultdict(list)
self.options = options

self.nodes: set[DAGNode] = set()
self.nodes: dict[str, DAGNode] = {}

def add_edge(
self,
child: "DAGNode",
parent: "DAGNode",
parent: str,
child: str,
res_name: str,
res_transformation: Callable = lambda x: x,
) -> None:
) -> tuple[str, str, str]:
"""
child: Downstream node
parent: Upstream node
Expand All @@ -40,8 +43,11 @@ def add_edge(

self.edges[parent].append((child, res_name, res_transformation))

def add_node(self, node: "DAGNode") -> None:
self.nodes.add(node)
return (parent, child, res_name)

def add_node(self, node: "DAGNode") -> "DAGNode":
self.nodes[node.name] = node
return node

async def execute(self) -> AsyncGenerator[tuple[str, str], None]:
if not self.is_dag:
Expand All @@ -55,28 +61,82 @@ async def execute(self) -> AsyncGenerator[tuple[str, str], None]:

while no_dep_nodes:
ready_node = no_dep_nodes.pop()
operator_name, res = await ready_node.execute(self.options)
operator_name, res = await self.nodes[ready_node].execute(self.options)

yield (operator_name, res)

for child, res_name, res_transformation in self.edges[ready_node]:
child.update(res_transformation(res), res_name)
self.nodes[child].update(res_name, res_transformation(res))
node_dep_count[child] -= 1
if node_dep_count[child] == 0:
no_dep_nodes.add(child)

def draw_mermaid(
self,
title: str | None = None,
direction: FlowchartDirection = FlowchartDirection.TOP_DOWN,
start_nodes: list[str] = [],
end_nodes: list[str] = [],
) -> str:
"""Draws a Mermaid flowchart from the DAG.

Args:
title (str, optional): Title of the flowchart. Defaults to None.
direction (FlowchartDirection, optional):
Direction of the flowchart, i.e. start and end positions. Defaults to top down.
start_nodes (list[str], optional): Names of the source (i.e. start) nodes. Defaults to project source nodes.
end_nodes (list[str], optional): Names of the sink (i.e. end) nodes. Defaults to project sink nodes.

Returns:
str: Mermaid flowchart syntax.
"""
flowchart = f"flowchart {direction}\n"

if title is not None:
flowchart = f"---\ntitle: {title}\n---\n" + flowchart

remove_whitespace: Callable[[str], str] = lambda string: "".join(string.split())
get_child: Callable[[tuple[str, str, Callable]], str] = lambda edge: edge[0]

def process_node(node: str) -> None:
nonlocal flowchart
flowchart = flowchart + f"\t{remove_whitespace(node)}([\"{self.nodes[node]!s}\"])\n"

def process_edge(node: str, edge: tuple[str, str, Callable]) -> None:
# TODO: design a good string representation for result transformation
nonlocal flowchart
flowchart = flowchart + f"\t{remove_whitespace(node)} -->|{edge[1]}| {remove_whitespace(edge[0])}\n"

if not start_nodes or not end_nodes:
sources, sinks = find_sources_and_sinks(self.nodes, self.edges, get_child)
if not start_nodes:
start_nodes = sources
if not end_nodes:
end_nodes = sinks

bfs_traversal(
self.edges,
start_nodes,
end_nodes,
process_node=process_node,
process_edge=process_edge,
get_neighbor=get_child,
)

return flowchart

@property
def is_dag(self):
def is_dag(self) -> bool:
# AI generated
visited = set()
rec_stack = set()
visited: set[str] = set()
rec_stack: set[str] = set()

def dfs(node: DAGNode) -> bool:
def dfs(node: str) -> bool:
if node not in visited:
visited.add(node)
rec_stack.add(node)

for child, _, _ in self.edges.get(node, []):
for child, _, _ in self.edges[node]:
if child not in visited:
if not dfs(child):
return False
Expand All @@ -102,6 +162,7 @@ class DAGNode:

def __init__(
self,
name: str,
task: str,
operator: Operator,
default_task_kwargs: dict[str, Any] = {},
Expand All @@ -119,12 +180,13 @@ def __init__(
raise ValueError(f"{operator} does not have a method {task}")
self.operator: Operator = operator

self.task_str = task
self.name = name
self.boost_str = task
self.dynamic_kwargs: dict[str, Any] = {}
self.default_task_kwargs = default_task_kwargs # TODO probably want to manage this in the project
self.options = options # Could also throw this into default_task_kwargs

def update(self, dyn_kwarg_value, dyn_kwarg_name) -> None:
def update(self, dyn_kwarg_name, dyn_kwarg_value) -> None:
self.dynamic_kwargs[dyn_kwarg_name] = dyn_kwarg_value

async def execute(self, options: dict = {}) -> Any:
Expand All @@ -133,11 +195,28 @@ async def execute(self, options: dict = {}) -> Any:
"""
kwargs = self.default_task_kwargs | self.dynamic_kwargs
options = self.options | options
print(kwargs)
res = self.bound_task(**kwargs, options=self.options | options)
if options.get("run_async"):
res = res.get().message

return type(self.operator).__name__, res
return self.name, res

def __str__(self):
return f"{type(self.operator).__name__}.{self.task_str}(**{self.default_task_kwargs})"
boost_signature = signature(getattr(self.operator.__class__, self.boost_str))
params = [
Parameter(
param.name,
param.kind,
default=(
param.default
if param.name not in self.default_task_kwargs
else self.default_task_kwargs[param.name]
),
annotation=param.annotation,
)
for param in boost_signature.parameters.values()
]
boost_signature = boost_signature.replace(parameters=params)
param_str = ", ".join(str(param) for param in boost_signature.parameters.values())
return f"{type(self.operator).__name__}.{self.boost_str}({param_str})"
3 changes: 1 addition & 2 deletions src/concrete-core/concrete/tools/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from datetime import datetime, timezone
from typing import Optional

from dotenv import dotenv_values

from concrete.clients import CLIClient
from concrete.models.base import ConcreteModel
from concrete.tools import MetaTool
from dotenv import dotenv_values


class Container(ConcreteModel):
Expand Down
Loading
Loading