From 4993b9b362281a8375d58fa3b9c5f52eeb3dab33 Mon Sep 17 00:00:00 2001
From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com>
Date: Tue, 1 Oct 2024 11:37:49 +0800
Subject: [PATCH] feat:add img under standing tool and refactor code to add
toolkits support (#57)
* fix:rename model curd function name to solve the confilct between curd and router
* feat:add img understanding tools
* feat:support toolkits
* fix:img understanding icon
---
.env.example | 4 +-
backend/app/api/routes/providermodel.py | 12 +--
backend/app/core/db.py | 12 ++-
backend/app/core/tools/ask_human/__init__.py | 3 +
backend/app/core/tools/google/__init__.py | 3 +
.../googletranslate.py | 0
backend/app/core/tools/math/__init__.py | 3 +
.../app/core/tools/openweather/__init__.py | 3 +
.../app/core/tools/siliconflow/__init__.py | 3 +
...{siliconflow.py => siliconflow_img_gen.py} | 2 +-
backend/app/core/tools/spark/__init__.py | 3 +
.../spark/{spark.py => spark_img_gen.py} | 2 +-
backend/app/core/tools/tool.py | 82 +++++++++++++++++++
backend/app/core/tools/tool_manager.py | 29 ++++---
backend/app/core/tools/zhipuai/__init__.py | 3 +
backend/app/core/tools/zhipuai/img_4v.py | 45 ++++++++++
backend/app/curd/models.py | 6 +-
backend/pyproject.toml | 3 +-
.../app/(applayout)/teams/[teamId]/page.tsx | 8 +-
web/src/app/(applayout)/tools/page.tsx | 15 ++--
web/src/components/Icons/Tools/index.tsx | 3 +-
.../components/Teams/DebugPreview/index.tsx | 1 +
.../components/Teams/NormalTeamSettings.tsx | 3 +-
23 files changed, 208 insertions(+), 40 deletions(-)
create mode 100644 backend/app/core/tools/ask_human/__init__.py
create mode 100644 backend/app/core/tools/google/__init__.py
rename backend/app/core/tools/{googletranslate => google}/googletranslate.py (100%)
create mode 100644 backend/app/core/tools/math/__init__.py
create mode 100644 backend/app/core/tools/openweather/__init__.py
create mode 100644 backend/app/core/tools/siliconflow/__init__.py
rename backend/app/core/tools/siliconflow/{siliconflow.py => siliconflow_img_gen.py} (95%)
create mode 100644 backend/app/core/tools/spark/__init__.py
rename backend/app/core/tools/spark/{spark.py => spark_img_gen.py} (98%)
create mode 100644 backend/app/core/tools/tool.py
create mode 100644 backend/app/core/tools/zhipuai/__init__.py
create mode 100644 backend/app/core/tools/zhipuai/img_4v.py
diff --git a/.env.example b/.env.example
index 7193fc38..9156d2ea 100644
--- a/.env.example
+++ b/.env.example
@@ -88,4 +88,6 @@ OPEN_WEATHER_API_KEY=changethis
# Spark get apikey from https://console.xfyun.cn/app/myapp
SPARK_APPID=changethis
SPARK_APISecret=changethis
-SPARK_APIKey=changethis
\ No newline at end of file
+SPARK_APIKey=changethis
+
+ZHIPUAI_API_KEY=changethis
\ No newline at end of file
diff --git a/backend/app/api/routes/providermodel.py b/backend/app/api/routes/providermodel.py
index fe5c70a6..c9680e67 100644
--- a/backend/app/api/routes/providermodel.py
+++ b/backend/app/api/routes/providermodel.py
@@ -2,11 +2,11 @@
from app.api.deps import SessionDep
from app.curd.models import (
- create_model,
- delete_model,
+ _create_model,
+ _delete_model,
get_all_models,
get_models_by_provider,
- update_model,
+ _update_model,
)
from app.models import Models, ModelsBase, ModelsOut
@@ -16,7 +16,7 @@
# Routes for Models
@router.post("/", response_model=ModelsBase)
def create_models(model: ModelsBase, session: SessionDep):
- return create_model(session, model)
+ return _create_model(session, model)
@router.get("/{provider_id}", response_model=ModelsOut)
@@ -31,7 +31,7 @@ def read_models(session: SessionDep):
@router.delete("/{model_id}", response_model=Models)
def delete_model(model_id: int, session: SessionDep):
- model = delete_model(session, model_id)
+ model = _delete_model(session, model_id)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
return model
@@ -39,7 +39,7 @@ def delete_model(model_id: int, session: SessionDep):
@router.put("/{model_id}", response_model=Models)
def update_model(model_id: int, model_update: ModelsBase, session: SessionDep):
- model = update_model(session, model_id, model_update)
+ model = _update_model(session, model_id, model_update)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
return model
diff --git a/backend/app/core/db.py b/backend/app/core/db.py
index 5c449715..2f8a074a 100644
--- a/backend/app/core/db.py
+++ b/backend/app/core/db.py
@@ -111,9 +111,15 @@ def init_modelprovider_model_db(session: Session) -> None:
(3, 'gpt4o-mini', 4),
(4, 'llama3.1:8b', 1),
(5, 'Qwen/Qwen2-7B-Instruct', 2),
- (6, 'glm-4', 3),
- (7, 'glm-4-0520', 3),
- (8, 'glm-4-flash', 3)
+
+ (6, 'glm-4-alltools', 3),
+ (7, 'glm-4-flash', 3),
+ (8, 'glm-4-0520', 3),
+ (9, 'glm-4-plus', 3),
+ (10, 'glm-4v-plus', 3),
+ (11, 'glm-4', 3),
+ (12, 'glm-4v', 3)
+
ON CONFLICT (id) DO NOTHING;
"""
diff --git a/backend/app/core/tools/ask_human/__init__.py b/backend/app/core/tools/ask_human/__init__.py
new file mode 100644
index 00000000..6ca6290a
--- /dev/null
+++ b/backend/app/core/tools/ask_human/__init__.py
@@ -0,0 +1,3 @@
+from .ask_human import ask_human
+
+__all__ = ["ask_human"]
diff --git a/backend/app/core/tools/google/__init__.py b/backend/app/core/tools/google/__init__.py
new file mode 100644
index 00000000..e5bea197
--- /dev/null
+++ b/backend/app/core/tools/google/__init__.py
@@ -0,0 +1,3 @@
+from .googletranslate import googletranslate
+
+__all__ = ["googletranslate"]
diff --git a/backend/app/core/tools/googletranslate/googletranslate.py b/backend/app/core/tools/google/googletranslate.py
similarity index 100%
rename from backend/app/core/tools/googletranslate/googletranslate.py
rename to backend/app/core/tools/google/googletranslate.py
diff --git a/backend/app/core/tools/math/__init__.py b/backend/app/core/tools/math/__init__.py
new file mode 100644
index 00000000..128ab3ab
--- /dev/null
+++ b/backend/app/core/tools/math/__init__.py
@@ -0,0 +1,3 @@
+from .math import math
+
+__all__ = ["math"]
diff --git a/backend/app/core/tools/openweather/__init__.py b/backend/app/core/tools/openweather/__init__.py
new file mode 100644
index 00000000..768ddc12
--- /dev/null
+++ b/backend/app/core/tools/openweather/__init__.py
@@ -0,0 +1,3 @@
+from .openweather import openweather
+
+__all__ = ["openweather"]
diff --git a/backend/app/core/tools/siliconflow/__init__.py b/backend/app/core/tools/siliconflow/__init__.py
new file mode 100644
index 00000000..9f7faca7
--- /dev/null
+++ b/backend/app/core/tools/siliconflow/__init__.py
@@ -0,0 +1,3 @@
+from .siliconflow_img_gen import siliconflow_img_generation
+
+__all__ = ["siliconflow_img_generation"]
diff --git a/backend/app/core/tools/siliconflow/siliconflow.py b/backend/app/core/tools/siliconflow/siliconflow_img_gen.py
similarity index 95%
rename from backend/app/core/tools/siliconflow/siliconflow.py
rename to backend/app/core/tools/siliconflow/siliconflow_img_gen.py
index 04ce08bf..6982f99f 100644
--- a/backend/app/core/tools/siliconflow/siliconflow.py
+++ b/backend/app/core/tools/siliconflow/siliconflow_img_gen.py
@@ -42,7 +42,7 @@ def text2img(
return json.dumps(f"There is a error occured . {e}")
-siliconflow = StructuredTool.from_function(
+siliconflow_img_generation = StructuredTool.from_function(
func=text2img,
name="Image Generation",
description="Siliconflow Image Generation is a tool that can generate images from text prompts using the Siliconflow API.",
diff --git a/backend/app/core/tools/spark/__init__.py b/backend/app/core/tools/spark/__init__.py
new file mode 100644
index 00000000..f455622d
--- /dev/null
+++ b/backend/app/core/tools/spark/__init__.py
@@ -0,0 +1,3 @@
+from .spark_img_gen import spark_img_generation
+
+__all__ = ["spark_img_generation"]
diff --git a/backend/app/core/tools/spark/spark.py b/backend/app/core/tools/spark/spark_img_gen.py
similarity index 98%
rename from backend/app/core/tools/spark/spark.py
rename to backend/app/core/tools/spark/spark_img_gen.py
index d8572bb4..0c567e0c 100644
--- a/backend/app/core/tools/spark/spark.py
+++ b/backend/app/core/tools/spark/spark_img_gen.py
@@ -133,7 +133,7 @@ def img_generation(prompt):
return json.dumps(f"There is a error occured . {e}")
-spark = StructuredTool.from_function(
+spark_img_generation = StructuredTool.from_function(
func=img_generation,
name="Spark Image Generation",
description="Spark Image Generation is a tool that can generate images from text prompts using the Spark API.",
diff --git a/backend/app/core/tools/tool.py b/backend/app/core/tools/tool.py
new file mode 100644
index 00000000..8dd9bcda
--- /dev/null
+++ b/backend/app/core/tools/tool.py
@@ -0,0 +1,82 @@
+from enum import Enum
+from typing import Any, Optional, Union, cast
+
+from pydantic import BaseModel, Field, field_validator
+
+
+class ToolInvokeMessage(BaseModel):
+ class MessageType(Enum):
+ TEXT = "text"
+ IMAGE = "image"
+ LINK = "link"
+ BLOB = "blob"
+ JSON = "json"
+ IMAGE_LINK = "image_link"
+
+ type: MessageType = MessageType.TEXT
+ """
+ plain text, image url or link url
+ """
+ message: Union[str, bytes, dict] = None
+ meta: dict[str, Any] = None
+ save_as: str = ""
+
+
+def create_image_message(image: str, save_as: str = "") -> ToolInvokeMessage:
+ """
+ create an image message
+
+ :param image: the url of the image
+ :return: the image message
+ """
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as
+ )
+
+
+def create_link_message(link: str, save_as: str = "") -> ToolInvokeMessage:
+ """
+ create a link message
+
+ :param link: the url of the link
+ :return: the link message
+ """
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as
+ )
+
+
+def create_text_message(text: str, save_as: str = "") -> ToolInvokeMessage:
+ """
+ create a text message
+
+ :param text: the text
+ :return: the text message
+ """
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as
+ )
+
+
+def create_blob_message(
+ blob: bytes, meta: dict = None, save_as: str = ""
+) -> ToolInvokeMessage:
+ """
+ create a blob message
+
+ :param blob: the blob
+ :return: the blob message
+ """
+ return ToolInvokeMessage(
+ type=ToolInvokeMessage.MessageType.BLOB,
+ message=blob,
+ meta=meta,
+ save_as=save_as,
+ )
+
+
+def create_json_message(object: dict) -> ToolInvokeMessage:
+ """
+ create a json message
+ """
+ return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object)
diff --git a/backend/app/core/tools/tool_manager.py b/backend/app/core/tools/tool_manager.py
index d09a8717..393f1144 100644
--- a/backend/app/core/tools/tool_manager.py
+++ b/backend/app/core/tools/tool_manager.py
@@ -30,20 +30,27 @@ def load_tools(self):
):
try:
module = importlib.import_module(
- f".{item}.{item}", package="app.core.tools"
+ f".{item}", package="app.core.tools"
)
- # Try to get the tool instance directly
- tool_instance = getattr(module, item)
- if isinstance(tool_instance, BaseTool):
- formatted_name = self.format_tool_name(item)
- self.managed_tools[formatted_name] = ToolInfo(
- description=tool_instance.description,
- tool=tool_instance,
- display_name=tool_instance.name,
- )
+ # Check if __all__ is defined in the module
+ if hasattr(module, "__all__"):
+ for tool_name in module.__all__:
+ tool_instance = getattr(module, tool_name, None)
+ if isinstance(tool_instance, BaseTool):
+ formatted_name = self.format_tool_name(item)
+ self.managed_tools[formatted_name] = ToolInfo(
+ description=tool_instance.description,
+ tool=tool_instance,
+ display_name=tool_instance.name,
+ )
+ else:
+ print(
+ f"Warning: {tool_name} in {item} is not an instance of BaseTool"
+ )
else:
- print(f"Warning: {item} is not an instance of BaseTool")
+ print(f"Warning: {item} does not define __all__")
+
except (ImportError, AttributeError) as e:
print(f"Failed to load tool {item}: {e}")
diff --git a/backend/app/core/tools/zhipuai/__init__.py b/backend/app/core/tools/zhipuai/__init__.py
new file mode 100644
index 00000000..4446dc98
--- /dev/null
+++ b/backend/app/core/tools/zhipuai/__init__.py
@@ -0,0 +1,3 @@
+from .img_4v import img_understanding
+
+__all__ = ["img_understanding"]
diff --git a/backend/app/core/tools/zhipuai/img_4v.py b/backend/app/core/tools/zhipuai/img_4v.py
new file mode 100644
index 00000000..5109dcac
--- /dev/null
+++ b/backend/app/core/tools/zhipuai/img_4v.py
@@ -0,0 +1,45 @@
+import base64
+from zhipuai import ZhipuAI
+import os
+from langchain.pydantic_v1 import BaseModel, Field
+
+from langchain.tools import StructuredTool
+
+
+class ImageUnderstandingInput(BaseModel):
+ """Input for the Image Understanding tool."""
+
+ text: str = Field(description="the input text for the Image Understanding tool")
+
+
+def img_4v(text: str):
+ img_path = "/Users/envys/Downloads/a.jpeg"
+ with open(img_path, "rb") as img_file:
+ img_base = base64.b64encode(img_file.read()).decode("utf-8")
+
+ client = ZhipuAI(
+ api_key=os.environ.get("ZHIPUAI_API_KEY"),
+ ) # 填写您自己的APIKey
+ response = client.chat.completions.create(
+ model="glm-4v", # 填写需要调用的模型名称
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {"type": "image_url", "image_url": {"url": img_base}},
+ {"type": "text", "text": text},
+ ],
+ }
+ ],
+ )
+
+ return response.choices[0].message
+
+
+img_understanding = StructuredTool.from_function(
+ func=img_4v,
+ name="Image Understanding",
+ description="Users input an image and a question, and the LLM can identify objects, scenes, and other information in the image to answer the user's question.",
+ args_schema=ImageUnderstandingInput,
+ return_direct=True,
+)
diff --git a/backend/app/curd/models.py b/backend/app/curd/models.py
index 2485aadd..3143d3cd 100644
--- a/backend/app/curd/models.py
+++ b/backend/app/curd/models.py
@@ -5,7 +5,7 @@
from ..models import ModelOut, ModelProviderOut, Models, ModelsBase, ModelsOut
-def create_model(session: Session, model: ModelsBase) -> Models:
+def _create_model(session: Session, model: ModelsBase) -> Models:
try:
db_model = Models(**model.model_dump())
session.add(db_model)
@@ -92,7 +92,7 @@ def get_all_models(session: Session) -> ModelsOut:
return ModelsOut(data=model_outs, count=total_count)
-def delete_model(session: Session, model_id: int) -> Optional[Models]:
+def _delete_model(session: Session, model_id: int) -> Optional[Models]:
db_model = session.exec(select(Models).where(Models.id == model_id)).first()
if db_model:
session.delete(db_model)
@@ -100,7 +100,7 @@ def delete_model(session: Session, model_id: int) -> Optional[Models]:
return db_model
-def update_model(
+def _update_model(
session: Session, model_id: int, model_update: ModelsBase
) -> Optional[Models]:
db_model = session.exec(select(Models).where(Models.id == model_id)).first()
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index a3d09520..45d2ca9b 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -14,7 +14,8 @@ passlib = {extras = ["bcrypt"], version = "^1.7.4"}
tenacity = "^8.2.3"
pydantic = ">2.0"
emails = "^0.6"
-
+zhipuai = "2.1.5.20230904"
+numexpr = "2.10.1"
gunicorn = "^22.0.0"
jinja2 = "^3.1.4"
alembic = "^1.12.1"
diff --git a/web/src/app/(applayout)/teams/[teamId]/page.tsx b/web/src/app/(applayout)/teams/[teamId]/page.tsx
index 4683daa3..90e187fd 100644
--- a/web/src/app/(applayout)/teams/[teamId]/page.tsx
+++ b/web/src/app/(applayout)/teams/[teamId]/page.tsx
@@ -10,14 +10,10 @@ import {
Spinner,
} from "@chakra-ui/react";
import { useQuery } from "react-query";
-
import Flow from "@/components/ReactFlow/Flow";
-
-import ShowFlow from "@/app/flow/show/ShowFlow";
import DebugPreview from "@/components/Teams/DebugPreview";
import NormalTeamSettings from "@/components/Teams/NormalTeamSettings";
import WorkflowTeamSettings from "@/components/Teams/WorkflowTeamSettings";
-import { color } from "framer-motion";
import Link from "next/link";
import { useParams } from "next/navigation";
import { useRef } from "react";
@@ -32,7 +28,7 @@ function Team() {
isError,
error,
} = useQuery(`team/${teamId}`, () =>
- TeamsService.readTeam({ id: Number.parseInt(teamId) }),
+ TeamsService.readTeam({ id: Number.parseInt(teamId) })
);
if (isError) {
@@ -118,7 +114,7 @@ function Team() {
) : (
-
+
)}
diff --git a/web/src/app/(applayout)/tools/page.tsx b/web/src/app/(applayout)/tools/page.tsx
index 557adab6..13de8e30 100644
--- a/web/src/app/(applayout)/tools/page.tsx
+++ b/web/src/app/(applayout)/tools/page.tsx
@@ -2,7 +2,7 @@
import { type ApiError, ToolsService } from "@/client";
import ActionsMenu from "@/components/Common/ActionsMenu";
import {
- Badge,
+ Text,
Box,
Flex,
Heading,
@@ -125,14 +125,19 @@ function Skills() {
.replace(/ /g, "_")}
/>
- {skill.display_name}
+
+ {skill.display_name}
+
- {skill.description}
+
+ {skill.description}
+
{!skill.managed ? (
diff --git a/web/src/components/Icons/Tools/index.tsx b/web/src/components/Icons/Tools/index.tsx
index bc93675e..b9a95b43 100644
--- a/web/src/components/Icons/Tools/index.tsx
+++ b/web/src/components/Icons/Tools/index.tsx
@@ -1,5 +1,5 @@
import { Icon, type IconProps, createIcon } from "@chakra-ui/icons";
-import { SiliconFlowIcon } from "../models";
+import { SiliconFlowIcon, ZhipuAIIcon } from "../models";
const OpenWeather = createIcon({
displayName: "OpenWeather",
@@ -367,6 +367,7 @@ const iconMap: { [key: string]: React.FC } = {
math_calculator: calculator,
image_generation: SiliconFlowIcon,
spark_image_generation: spark,
+ image_understanding: ZhipuAIIcon,
};
const DefaultIcon = Wikipedia;
diff --git a/web/src/components/Teams/DebugPreview/index.tsx b/web/src/components/Teams/DebugPreview/index.tsx
index 7173bb28..d8d9888b 100644
--- a/web/src/components/Teams/DebugPreview/index.tsx
+++ b/web/src/components/Teams/DebugPreview/index.tsx
@@ -19,6 +19,7 @@ function DebugPreview({
borderRadius={"lg"}
display={"flex"}
flexDirection={"column"}
+ overflow={"hidden"}
>
-
+
{member?.map((member) => (