diff --git a/backend/app/core/workflow/db_utils.py b/backend/app/core/workflow/db_utils.py
new file mode 100644
index 00000000..2f4192c3
--- /dev/null
+++ b/backend/app/core/workflow/db_utils.py
@@ -0,0 +1,39 @@
+from contextlib import contextmanager
+from typing import Callable, TypeVar, Any
+from sqlmodel import Session
+from app.core.db import engine
+
+T = TypeVar('T')
+
+@contextmanager
+def get_db_session():
+    session = Session(engine)
+    try:
+        yield session
+        session.commit()
+    except Exception:
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+def db_operation(operation: Callable[[Session], T]) -> T:
+    """
+    Helper for running a database operation inside a managed session.
+
+    :param operation: A callable that accepts a Session and returns a result.
+    :return: The result of the operation.
+    """
+    with get_db_session() as session:
+        return operation(session)
+
+# Example usage
+def get_all_models_helper():
+    from app.curd.models import get_all_models
+    return db_operation(get_all_models)
+
+def get_models_by_provider_helper(provider_id: int):
+    from app.curd.models import get_models_by_provider
+    return db_operation(lambda session: get_models_by_provider(session, provider_id))
+
+# More helper functions can be added here as needed
\ No newline at end of file
diff --git a/backend/app/core/workflow/init_graph.py b/backend/app/core/workflow/init_graph.py
index 6ae619fb..eaf2151c 100644
--- a/backend/app/core/workflow/init_graph.py
+++ b/backend/app/core/workflow/init_graph.py
@@ -1,6 +1,6 @@
 from langchain.pydantic_v1 import BaseModel
 from langchain.tools import BaseTool
-from typing import Dict, Any, Set 
+from typing import Dict, Any, Set
 from functools import lru_cache
 import time
 from langgraph.graph.graph import CompiledGraph
@@ -10,6 +10,10 @@
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph import END, StateGraph
 from langchain_core.runnables import RunnableLambda
+
+from app.curd.models import get_all_models
+from app.api.deps import SessionDep
+from app.core.workflow.db_utils import get_all_models_helper
 from .node import (
     WorkerNode,
     SequentialWorkerNode,
@@ -59,7 +63,9 @@ def should_continue(state: TeamState) -> str:
 
 
 def initialize_graph(
-    build_config: Dict[str, Any], checkpointer: BaseCheckpointSaver,save_graph_img=False
+    build_config: Dict[str, Any],
+    checkpointer: BaseCheckpointSaver,
+    save_graph_img=False,
 ) -> CompiledGraph:
     global tool_name_to_node_id
 
@@ -120,13 +126,31 @@ def initialize_graph(
             node_data = node["data"]
 
             if node_type == "llm":
+                model_name = node_data["model"]
+                all_models = get_all_models_helper()
+                model_info = None
+                for model in all_models.data:
+                    if model.ai_model_name == model_name:
+                        model_info = {
+                            "ai_model_name": model.ai_model_name,
+                            "provider_name": model.provider.provider_name,
+                            "base_url": model.provider.base_url,
+                            "api_key": model.provider.api_key,
+                        }
+                        break
+                if model_info is None:
+                    raise ValueError(f"Model {model_name} not supported now.")
+                # TODO: in the future we can use more langchain templates here and apply them to different node types
                 if is_sequential:
-                    node_class = SequentialWorkerNode
+                    # node_class = SequentialWorkerNode
+                    node_class = LLMNode
                 elif is_hierarchical:
-                    if llm_children[node_id]:  # If the node has child LLM nodes
-                        node_class = LeaderNode
+                    if llm_children[node_id]:
+                        # node_class = LeaderNode
+                        node_class = LLMNode
                     else:
-                        node_class = WorkerNode
+                        # node_class = WorkerNode
+                        node_class = LLMNode
                 else:
                     node_class = LLMNode
 
@@ -149,11 +173,11 @@ def initialize_graph(
                 node_id,
                 RunnableLambda(
                     node_class(
-                        provider=node_data.get("provider", "zhipuai"),
-                        model=node_data["model"],
+                        provider=model_info["provider_name"],
+                        model=model_info["ai_model_name"],
                         tools=tools_to_bind,
-                        openai_api_key="",
-                        openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
+                        openai_api_key=model_info["api_key"],
+                        openai_api_base=model_info["base_url"],
                         temperature=node_data["temperature"],
                     ).work
                 ),
diff --git a/backend/app/core/workflow/node.py b/backend/app/core/workflow/node.py
index 590c4b14..ddfbd674 100644
--- a/backend/app/core/workflow/node.py
+++ b/backend/app/core/workflow/node.py
@@ -158,43 +158,24 @@ def __init__(
         self,
         provider: str,
         model: str,
-        tools:Sequence[BaseTool],
+        tools: Sequence[BaseTool],
         openai_api_key: str,
         openai_api_base: str,
         temperature: float,
     ):
-        if provider in ["zhipuai"] and openai_api_base:
-
-            # self.model = ChatOpenAI(
-            #     model=model,
-            #     streaming=True,
-            #     openai_api_key=openai_api_key,
-            #     openai_api_base=openai_api_base,
-            #     temperature=temperature,
-            # )
-            # self.final_answer_model = ChatOpenAI(
-            #     model=model,
-            #     streaming=True,
-            #     openai_api_key=openai_api_key,
-            #     openai_api_base=openai_api_base,
-            #     temperature=0,
-            # )
-
+        if provider in ["zhipuai", "Siliconflow"]:
             self.model = ChatOpenAI(
-                # model="chatglm_turbo",
-                model="glm-4-flash",
-                temperature=0.01,
-                # openai_api_key='9953866f9b7fac2fd6d564842d8bcc79.AbXduj53KA3SDSMs',
-                # openai_api_key='fe1f2097b7284bd4baa1284be8d54aea.6VOvX4efbye8M6m0',
-                openai_api_key="1a65e1fed7ab7a788ee94d73570e9fcf.5FVs3ceE6POvEnSN",
-                openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
+                model=model,
+                temperature=temperature,
+                openai_api_key=openai_api_key,
+                openai_api_base=openai_api_base,
             )
 
-            if len(tools)>=1:
+            if len(tools) >= 1:
                 self.model = self.model.bind_tools(tools)
             self.final_answer_model = self.model
 
-        elif provider in ["openai"] and openai_api_base:
+        elif provider in ["openai"]:
             self.model = init_chat_model(
                 model,
                 model_provider=provider,
@@ -362,8 +343,7 @@ class LLMNode(BaseNode):
             "system",
             (
                 "Perform the task given to you.\n"
-                "If you are unable to perform the task, that's OK, another member with different tools "
-                "will help where you left off. Do not attempt to communicate with other members. "
+                "If you are unable to perform the task, that's OK; you can ask a human for help, or just say that you are unable to perform the task. "
                 "Execute what you can to make progress. "
" "Stay true to your role and use your tools if necessary.\n\n" ), @@ -376,21 +356,6 @@ class LLMNode(BaseNode): ] ) - def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage: - """Tag a name to the AI message""" - ai_message.name = name - return ai_message - - def get_next_member_in_sequence( - self, members: Mapping[str, GraphMember | GraphLeader], current_name: str - ) -> str | None: - member_names = list(members.keys()) - next_index = member_names.index(current_name) + 1 - if next_index < len(members): - return member_names[member_names.index(current_name) + 1] - else: - return None - async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamState: history = state.get("history", []) messages = state.get("messages", []) diff --git a/web/src/components/Teams/WorkflowTeamSettings.tsx b/web/src/components/Teams/WorkflowTeamSettings.tsx index 15a4dd92..d6f6a3fc 100644 --- a/web/src/components/Teams/WorkflowTeamSettings.tsx +++ b/web/src/components/Teams/WorkflowTeamSettings.tsx @@ -58,13 +58,13 @@ function WorkflowTeamSettings({ teamId, triggerSubmit }: WorkflowSettingProps) { { id: "end", type: "end", - position: { x: 891.4025316455695, y: 221.5569620253164 }, + position: { x: 891, y: 221 }, data: { label: "End" }, }, { id: "llm", type: "llm", - position: { x: 500.04430379746833, y: 219.95189873417723 }, + position: { x: 500, y: 219 }, data: { label: "LLM", model: "glm-4-flash", temperature: 0.1 }, }, ], diff --git a/web/src/components/WorkFlow/nodes/Tool/Properties.tsx b/web/src/components/WorkFlow/nodes/Tool/Properties.tsx index 5cdccdea..e49644b0 100644 --- a/web/src/components/WorkFlow/nodes/Tool/Properties.tsx +++ b/web/src/components/WorkFlow/nodes/Tool/Properties.tsx @@ -15,7 +15,7 @@ const ToolNodeProperties: React.FC = ({ Tool: