Skip to content

Commit

Permalink
Feat/wf backend model (#41)
Browse files Browse the repository at this point in the history
* feat: refactor backend node code

* feat: add db_utils and now get model info from database

* feat: refactor basenode code with a better init model according to provider and model name

* fix: change tool node data tool to tools
  • Loading branch information
Onelevenvy authored Sep 24, 2024
1 parent 64c40ed commit 52eaddd
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 58 deletions.
39 changes: 39 additions & 0 deletions backend/app/core/workflow/db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from contextlib import contextmanager
from typing import Callable, TypeVar, Any
from sqlmodel import Session
from app.core.db import engine

T = TypeVar('T')

@contextmanager
def get_db_session():
    """Yield a SQLModel ``Session`` bound to the application engine.

    The session is committed when the ``with`` block exits cleanly.
    If the block (or the commit itself) raises, the transaction is
    rolled back and the exception is re-raised.  The session is always
    closed, regardless of outcome.
    """
    db = Session(engine)
    try:
        yield db
        # Commit inside the try so a failing commit also triggers rollback.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

def db_operation(operation: Callable[[Session], T]) -> T:
    """Run *operation* inside a managed database session.

    :param operation: a callable that accepts an open ``Session`` and
        returns a result.
    :return: whatever the callable returns.  The session is committed
        (or rolled back on error) by ``get_db_session``.
    """
    with get_db_session() as db:
        result = operation(db)
    return result

# Example usage
def get_all_models_helper():
    """Fetch all model records through a managed session."""
    # Local import — presumably avoids a circular import at module load;
    # confirm against app.curd.models.
    from app.curd.models import get_all_models

    return db_operation(get_all_models)

def get_models_by_provider_helper(provider_id: int):
    """Fetch the model records belonging to one provider.

    :param provider_id: primary key of the provider to filter by.
    """
    # Local import — presumably avoids a circular import at module load;
    # confirm against app.curd.models.
    from app.curd.models import get_models_by_provider

    def _fetch(session):
        return get_models_by_provider(session, provider_id)

    return db_operation(_fetch)

# Add more helper functions here as needed
44 changes: 34 additions & 10 deletions backend/app/core/workflow/init_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain.pydantic_v1 import BaseModel
from langchain.tools import BaseTool
from typing import Dict, Any, Set
from typing import Dict, Any, Set
from functools import lru_cache
import time
from langgraph.graph.graph import CompiledGraph
Expand All @@ -10,6 +10,10 @@
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph import END, StateGraph
from langchain_core.runnables import RunnableLambda

from app.curd.models import get_all_models
from app.api.deps import SessionDep
from app.core.workflow.db_utils import get_all_models_helper
from .node import (
WorkerNode,
SequentialWorkerNode,
Expand Down Expand Up @@ -59,7 +63,9 @@ def should_continue(state: TeamState) -> str:


def initialize_graph(
build_config: Dict[str, Any], checkpointer: BaseCheckpointSaver,save_graph_img=False
build_config: Dict[str, Any],
checkpointer: BaseCheckpointSaver,
save_graph_img=False,
) -> CompiledGraph:
global tool_name_to_node_id

Expand Down Expand Up @@ -120,13 +126,31 @@ def initialize_graph(
node_data = node["data"]

if node_type == "llm":
model_name = node_data["model"]
all_models = get_all_models_helper()
model_info = None
for model in all_models.data:
if model.ai_model_name == model_name:
model_info = {
"ai_model_name": model.ai_model_name,
"provider_name": model.provider.provider_name,
"base_url": model.provider.base_url,
"api_key": model.provider.api_key,
}
break
if model_info is None:
raise ValueError(f"Model {model_name} not supported now.")
# TODO: in the future we can use more LangChain templates here, applied to different node types
if is_sequential:
node_class = SequentialWorkerNode
# node_class = SequentialWorkerNode
node_class = LLMNode
elif is_hierarchical:
if llm_children[node_id]: # If the node has child LLM nodes
node_class = LeaderNode
if llm_children[node_id]:
# node_class = LeaderNode
node_class = LLMNode
else:
node_class = WorkerNode
# node_class = WorkerNode
node_class = LLMNode
else:
node_class = LLMNode

Expand All @@ -149,11 +173,11 @@ def initialize_graph(
node_id,
RunnableLambda(
node_class(
provider=node_data.get("provider", "zhipuai"),
model=node_data["model"],
provider=model_info["provider_name"],
model=model_info["ai_model_name"],
tools=tools_to_bind,
openai_api_key="",
openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
openai_api_key=model_info["api_key"],
openai_api_base=model_info["base_url"],
temperature=node_data["temperature"],
).work
),
Expand Down
53 changes: 9 additions & 44 deletions backend/app/core/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,43 +158,24 @@ def __init__(
self,
provider: str,
model: str,
tools:Sequence[BaseTool],
tools: Sequence[BaseTool],
openai_api_key: str,
openai_api_base: str,
temperature: float,
):

if provider in ["zhipuai"] and openai_api_base:

# self.model = ChatOpenAI(
# model=model,
# streaming=True,
# openai_api_key=openai_api_key,
# openai_api_base=openai_api_base,
# temperature=temperature,
# )
# self.final_answer_model = ChatOpenAI(
# model=model,
# streaming=True,
# openai_api_key=openai_api_key,
# openai_api_base=openai_api_base,
# temperature=0,
# )

if provider in ["zhipuai", "Siliconflow"]:
self.model = ChatOpenAI(
# model="chatglm_turbo",
model="glm-4-flash",
temperature=0.01,
# openai_api_key='9953866f9b7fac2fd6d564842d8bcc79.AbXduj53KA3SDSMs',
# openai_api_key='fe1f2097b7284bd4baa1284be8d54aea.6VOvX4efbye8M6m0',
openai_api_key="1a65e1fed7ab7a788ee94d73570e9fcf.5FVs3ceE6POvEnSN",
openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
model=model,
temperature=temperature,
openai_api_key=openai_api_key,
openai_api_base=openai_api_base,
)
if len(tools)>=1:
if len(tools) >= 1:
self.model = self.model.bind_tools(tools)
self.final_answer_model = self.model

elif provider in ["openai"] and openai_api_base:
elif provider in ["openai"]:
self.model = init_chat_model(
model,
model_provider=provider,
Expand Down Expand Up @@ -362,8 +343,7 @@ class LLMNode(BaseNode):
"system",
(
"Perform the task given to you.\n"
"If you are unable to perform the task, that's OK, another member with different tools "
"will help where you left off. Do not attempt to communicate with other members. "
"If you are unable to perform the task, that's OK, you can ask human for help, or just say that you are unable to perform the task."
"Execute what you can to make progress. "
"Stay true to your role and use your tools if necessary.\n\n"
),
Expand All @@ -376,21 +356,6 @@ class LLMNode(BaseNode):
]
)

def tag_with_name(self, ai_message: AIMessage, name: str) -> AIMessage:
"""Tag a name to the AI message"""
ai_message.name = name
return ai_message

def get_next_member_in_sequence(
self, members: Mapping[str, GraphMember | GraphLeader], current_name: str
) -> str | None:
member_names = list(members.keys())
next_index = member_names.index(current_name) + 1
if next_index < len(members):
return member_names[member_names.index(current_name) + 1]
else:
return None

async def work(self, state: TeamState, config: RunnableConfig) -> ReturnTeamState:
history = state.get("history", [])
messages = state.get("messages", [])
Expand Down
4 changes: 2 additions & 2 deletions web/src/components/Teams/WorkflowTeamSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ function WorkflowTeamSettings({ teamId, triggerSubmit }: WorkflowSettingProps) {
{
id: "end",
type: "end",
position: { x: 891.4025316455695, y: 221.5569620253164 },
position: { x: 891, y: 221 },
data: { label: "End" },
},
{
id: "llm",
type: "llm",
position: { x: 500.04430379746833, y: 219.95189873417723 },
position: { x: 500, y: 219 },
data: { label: "LLM", model: "glm-4-flash", temperature: 0.1 },
},
],
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/WorkFlow/nodes/Tool/Properties.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ const ToolNodeProperties: React.FC<ToolNodePropertiesProps> = ({
<Box>
<Text fontWeight="bold">Tool:</Text>
<Select
value={node.data.tool}
value={node.data.tools}
onChange={(e) => onNodeDataChange(node.id, "tool", e.target.value)}
>
<option value="calculator">Calculator</option>
Expand Down
2 changes: 1 addition & 1 deletion web/src/components/WorkFlow/nodes/Tool/ToolNode.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const ToolNode: React.FC<NodeProps> = (props) => {
<Handle type="source" position={Position.Left} id="left" />
<Handle type="source" position={Position.Right} id="right" />
<Box bg="#f2f4f7" borderRadius={"md"} w="full" p="2">
<Text fontSize="xs">{props.data.tool}</Text>
<Text fontSize="xs">{props.data.tools}</Text>
</Box>
</BaseNode>
);
Expand Down

0 comments on commit 52eaddd

Please sign in to comment.