Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/wf backend model #41

Merged
merged 4 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions backend/app/core/workflow/db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from contextlib import contextmanager
from typing import Callable, TypeVar, Any
from sqlmodel import Session
from app.core.db import engine

T = TypeVar('T')

@contextmanager
def get_db_session():
session = Session(engine)
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

def db_operation(operation: Callable[[Session], T]) -> T:
"""
执行数据库操作的辅助函数。

:param operation: 一个接受 Session 作为参数并返回结果的函数。
:return: 操作的结果。
"""
with get_db_session() as session:
return operation(session)

# 示例用法
def get_all_models_helper():
from app.curd.models import get_all_models
return db_operation(get_all_models)

def get_models_by_provider_helper(provider_id: int):
from app.curd.models import get_models_by_provider
return db_operation(lambda session: get_models_by_provider(session, provider_id))

# 可以根据需要添加更多辅助函数
44 changes: 34 additions & 10 deletions backend/app/core/workflow/init_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from typing import Dict, Any, Set
from typing import Dict, Any, Set
from functools import lru_cache
import time
from langgraph.graph.graph import CompiledGraph
Expand All @@ -10,6 +10,10 @@
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langchain_core.runnables import RunnableLambda

from app.curd.models import get_all_models
from app.api.deps import SessionDep
from app.core.workflow.db_utils import get_all_models_helper
from .node import (
WorkerNode,
SequentialWorkerNode,
Expand Down Expand Up @@ -59,7 +63,9 @@ def should_continue(state: TeamState) -> str:


def initialize_graph(
build_config: Dict[str, Any], checkpointer: BaseCheckpointSaver,save_graph_img=False
build_config: Dict[str, Any],
checkpointer: BaseCheckpointSaver,
save_graph_img=False,
) -> CompiledGraph:
global tool_name_to_node_id

Expand Down Expand Up @@ -120,13 +126,31 @@ def initialize_graph(
node_data = node["data"]

if node_type == "llm":
model_name = node_data["model"]
all_models = get_all_models_helper()
model_info = None
for model in all_models.data:
if model.ai_model_name == model_name:
model_info = {
"ai_model_name": model.ai_model_name,
"provider_name": model.provider.provider_name,
"base_url": model.provider.base_url,
"api_key": model.provider.api_key,
}
break
if model_info is None:
raise ValueError(f"Model {model_name} not supported now.")
# in the future wo can use more langchain templates here apply to different node type TODO
if is_sequential:
node_class = SequentialWorkerNode
# node_class = SequentialWorkerNode
node_class = LLMNode
elif is_hierarchical:
if llm_children[node_id]: # If the node has child LLM nodes
node_class = LeaderNode
if llm_children[node_id]:
# node_class = LeaderNode
node_class = LLMNode
else:
node_class = WorkerNode
# node_class = WorkerNode
node_class = LLMNode
else:
node_class = LLMNode

Expand All @@ -149,11 +173,11 @@ def initialize_graph(
node_id,
RunnableLambda(
node_class(
provider=node_data.get("provider", "zhipuai"),
model=node_data["model"],
provider=model_info["provider_name"],
model=model_info["ai_model_name"],
tools=tools_to_bind,
openai_api_key="",
openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
openai_api_key=model_info["api_key"],
openai_api_base=model_info["base_url"],
temperature=node_data["temperature"],
).work
),
Expand Down
53 changes: 9 additions & 44 deletions backend/app/core/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,43 +158,24 @@ def __init__(
self,
provider: str,
model: str,
tools:Sequence[BaseTool],
tools: Sequence[BaseTool],
openai_api_key: str,
openai_api_base: str,
temperature: float,
):

if provider in ["zhipuai"] and openai_api_base:

# self.model = ChatOpenAI(
# model=model,
# streaming=True,
# openai_api_key=openai_api_key,
# openai_api_base=openai_api_base,
# temperature=temperature,
# )
# self.final_answer_model = ChatOpenAI(
# model=model,
# streaming=True,
# openai_api_key=openai_api_key,
# openai_api_base=openai_api_base,
# temperature=0,
# )

if provider in ["zhipuai", "Siliconflow"]:
self.model = ChatOpenAI(
# model="chatglm_turbo",
model="glm-4-flash",
temperature=0.01,
# openai_api_key='9953866f9b7fac2fd6d564842d8bcc79.AbXduj53KA3SDSMs',
# openai_api_key='fe1f2097b7284bd4baa1284be8d54aea.6VOvX4efbye8M6m0',
openai_api_key="1a65e1fed7ab7a788ee94d73570e9fcf.5FVs3ceE6POvEnSN",
openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
model=model,
temperature=temperature,
openai_api_key=openai_api_key,
openai_api_base=openai_api_base,
)
if len(tools)>=1:
if len(tools) >= 1:
self.model = self.model.bind_tools(tools)
self.final_answer_model = self.model

elif provider in ["openai"] and openai_api_base:
elif provider in ["openai"]:
self.model = init_chat_model(
model,
model_provider=provider,
Expand Down Expand Up @@ -362,8 +343,7 @@ class LLMNode(BaseNode):
"system",
(
"Perform the task given to you.\n"
"If you are unable to perform the task, that's OK, another member with different tools "
"will help where you left off. Do not attempt to communicate with other members. "
"If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
"Execute what you can to make progress. "
"Stay true to your role and use your tools if necessary.\n\n"
),
Expand All @@ -376,21 +356,6 @@ class LLMNode(BaseNode):
]
)

def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage:
"""Tag a name to the AI message"""
ai_message.name = name
return ai_message

def get_next_member_in_sequence(
self, members: Mapping[str, GraphMember | GraphLeader], current_name: str
) -> str | None:
member_names = list(members.keys())
next_index = member_names.index(current_name) + 1
if next_index < len(members):
return member_names[member_names.index(current_name) + 1]
else:
return None

async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamState:
history = state.get("history", [])
messages = state.get("messages", [])
Expand Down
4 changes: 2 additions & 2 deletions web/src/components/Teams/WorkflowTeamSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ function WorkflowTeamSettings({ teamId, triggerSubmit }: WorkflowSettingProps) {
{
id: "end",
type: "end",
position: { x: 891.4025316455695, y: 221.5569620253164 },
position: { x: 891, y: 221 },
data: { label: "End" },
},
{
id: "llm",
type: "llm",
position: { x: 500.04430379746833, y: 219.95189873417723 },
position: { x: 500, y: 219 },
data: { label: "LLM", model: "glm-4-flash", temperature: 0.1 },
},
],
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/WorkFlow/nodes/Tool/Properties.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const ToolNodeProperties: React.FC<ToolNodePropertiesProps> = ({
<Box>
<Text fontWeight="bold">Tool:</Text>
<Select
value={node.data.tool}
value={node.data.tools}
onChange={(e) => onNodeDataChange(node.id, "tool", e.target.value)}
>
<option value="calculator">Calculator</option>
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/WorkFlow/nodes/Tool/ToolNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const ToolNode: React.FC<NodeProps> = (props) => {
<Handle type="source" position={Position.Left} id="left" />
<Handle type="source" position={Position.Right} id="right" />
<Box bg="#f2f4f7" borderRadius={"md"} w="full" p="2">
<Text fontSize="xs">{props.data.tool}</Text>
<Text fontSize="xs">{props.data.tools}</Text>
</Box>
</BaseNode>
);
Expand Down
Loading