From 4993b9b362281a8375d58fa3b9c5f52eeb3dab33 Mon Sep 17 00:00:00 2001 From: Onelevenvy <49232224+Onelevenvy@users.noreply.github.com> Date: Tue, 1 Oct 2024 11:37:49 +0800 Subject: [PATCH] feat:add img understanding tool and refactor code to add toolkits support (#57) * fix:rename model curd function name to solve the conflict between curd and router * feat:add img understanding tools * feat:support toolkits * fix:img understanding icon --- .env.example | 4 +- backend/app/api/routes/providermodel.py | 12 +-- backend/app/core/db.py | 12 ++- backend/app/core/tools/ask_human/__init__.py | 3 + backend/app/core/tools/google/__init__.py | 3 + .../googletranslate.py | 0 backend/app/core/tools/math/__init__.py | 3 + .../app/core/tools/openweather/__init__.py | 3 + .../app/core/tools/siliconflow/__init__.py | 3 + ...{siliconflow.py => siliconflow_img_gen.py} | 2 +- backend/app/core/tools/spark/__init__.py | 3 + .../spark/{spark.py => spark_img_gen.py} | 2 +- backend/app/core/tools/tool.py | 82 +++++++++++++++++++ backend/app/core/tools/tool_manager.py | 29 ++++--- backend/app/core/tools/zhipuai/__init__.py | 3 + backend/app/core/tools/zhipuai/img_4v.py | 45 ++++++++++ backend/app/curd/models.py | 6 +- backend/pyproject.toml | 3 +- .../app/(applayout)/teams/[teamId]/page.tsx | 8 +- web/src/app/(applayout)/tools/page.tsx | 15 ++-- web/src/components/Icons/Tools/index.tsx | 3 +- .../components/Teams/DebugPreview/index.tsx | 1 + .../components/Teams/NormalTeamSettings.tsx | 3 +- 23 files changed, 208 insertions(+), 40 deletions(-) create mode 100644 backend/app/core/tools/ask_human/__init__.py create mode 100644 backend/app/core/tools/google/__init__.py rename backend/app/core/tools/{googletranslate => google}/googletranslate.py (100%) create mode 100644 backend/app/core/tools/math/__init__.py create mode 100644 backend/app/core/tools/openweather/__init__.py create mode 100644 backend/app/core/tools/siliconflow/__init__.py rename
backend/app/core/tools/siliconflow/{siliconflow.py => siliconflow_img_gen.py} (95%) create mode 100644 backend/app/core/tools/spark/__init__.py rename backend/app/core/tools/spark/{spark.py => spark_img_gen.py} (98%) create mode 100644 backend/app/core/tools/tool.py create mode 100644 backend/app/core/tools/zhipuai/__init__.py create mode 100644 backend/app/core/tools/zhipuai/img_4v.py diff --git a/.env.example b/.env.example index 7193fc38..9156d2ea 100644 --- a/.env.example +++ b/.env.example @@ -88,4 +88,6 @@ OPEN_WEATHER_API_KEY=changethis # Spark get apikey from https://console.xfyun.cn/app/myapp SPARK_APPID=changethis SPARK_APISecret=changethis -SPARK_APIKey=changethis \ No newline at end of file +SPARK_APIKey=changethis + +ZHIPUAI_API_KEY=changethis \ No newline at end of file diff --git a/backend/app/api/routes/providermodel.py b/backend/app/api/routes/providermodel.py index fe5c70a6..c9680e67 100644 --- a/backend/app/api/routes/providermodel.py +++ b/backend/app/api/routes/providermodel.py @@ -2,11 +2,11 @@ from app.api.deps import SessionDep from app.curd.models import ( - create_model, - delete_model, + _create_model, + _delete_model, get_all_models, get_models_by_provider, - update_model, + _update_model, ) from app.models import Models, ModelsBase, ModelsOut @@ -16,7 +16,7 @@ # Routes for Models @router.post("/", response_model=ModelsBase) def create_models(model: ModelsBase, session: SessionDep): - return create_model(session, model) + return _create_model(session, model) @router.get("/{provider_id}", response_model=ModelsOut) @@ -31,7 +31,7 @@ def read_models(session: SessionDep): @router.delete("/{model_id}", response_model=Models) def delete_model(model_id: int, session: SessionDep): - model = delete_model(session, model_id) + model = _delete_model(session, model_id) if model is None: raise HTTPException(status_code=404, detail="Model not found") return model @@ -39,7 +39,7 @@ def delete_model(model_id: int, session: SessionDep): 
@router.put("/{model_id}", response_model=Models) def update_model(model_id: int, model_update: ModelsBase, session: SessionDep): - model = update_model(session, model_id, model_update) + model = _update_model(session, model_id, model_update) if model is None: raise HTTPException(status_code=404, detail="Model not found") return model diff --git a/backend/app/core/db.py b/backend/app/core/db.py index 5c449715..2f8a074a 100644 --- a/backend/app/core/db.py +++ b/backend/app/core/db.py @@ -111,9 +111,15 @@ def init_modelprovider_model_db(session: Session) -> None: (3, 'gpt4o-mini', 4), (4, 'llama3.1:8b', 1), (5, 'Qwen/Qwen2-7B-Instruct', 2), - (6, 'glm-4', 3), - (7, 'glm-4-0520', 3), - (8, 'glm-4-flash', 3) + + (6, 'glm-4-alltools', 3), + (7, 'glm-4-flash', 3), + (8, 'glm-4-0520', 3), + (9, 'glm-4-plus', 3), + (10, 'glm-4v-plus', 3), + (11, 'glm-4', 3), + (12, 'glm-4v', 3) + ON CONFLICT (id) DO NOTHING; """ diff --git a/backend/app/core/tools/ask_human/__init__.py b/backend/app/core/tools/ask_human/__init__.py new file mode 100644 index 00000000..6ca6290a --- /dev/null +++ b/backend/app/core/tools/ask_human/__init__.py @@ -0,0 +1,3 @@ +from .ask_human import ask_human + +__all__ = ["ask_human"] diff --git a/backend/app/core/tools/google/__init__.py b/backend/app/core/tools/google/__init__.py new file mode 100644 index 00000000..e5bea197 --- /dev/null +++ b/backend/app/core/tools/google/__init__.py @@ -0,0 +1,3 @@ +from .googletranslate import googletranslate + +__all__ = ["googletranslate"] diff --git a/backend/app/core/tools/googletranslate/googletranslate.py b/backend/app/core/tools/google/googletranslate.py similarity index 100% rename from backend/app/core/tools/googletranslate/googletranslate.py rename to backend/app/core/tools/google/googletranslate.py diff --git a/backend/app/core/tools/math/__init__.py b/backend/app/core/tools/math/__init__.py new file mode 100644 index 00000000..128ab3ab --- /dev/null +++ b/backend/app/core/tools/math/__init__.py @@ -0,0 +1,3 
@@ +from .math import math + +__all__ = ["math"] diff --git a/backend/app/core/tools/openweather/__init__.py b/backend/app/core/tools/openweather/__init__.py new file mode 100644 index 00000000..768ddc12 --- /dev/null +++ b/backend/app/core/tools/openweather/__init__.py @@ -0,0 +1,3 @@ +from .openweather import openweather + +__all__ = ["openweather"] diff --git a/backend/app/core/tools/siliconflow/__init__.py b/backend/app/core/tools/siliconflow/__init__.py new file mode 100644 index 00000000..9f7faca7 --- /dev/null +++ b/backend/app/core/tools/siliconflow/__init__.py @@ -0,0 +1,3 @@ +from .siliconflow_img_gen import siliconflow_img_generation + +__all__ = ["siliconflow_img_generation"] diff --git a/backend/app/core/tools/siliconflow/siliconflow.py b/backend/app/core/tools/siliconflow/siliconflow_img_gen.py similarity index 95% rename from backend/app/core/tools/siliconflow/siliconflow.py rename to backend/app/core/tools/siliconflow/siliconflow_img_gen.py index 04ce08bf..6982f99f 100644 --- a/backend/app/core/tools/siliconflow/siliconflow.py +++ b/backend/app/core/tools/siliconflow/siliconflow_img_gen.py @@ -42,7 +42,7 @@ def text2img( return json.dumps(f"There is a error occured . 
{e}") -siliconflow = StructuredTool.from_function( +siliconflow_img_generation = StructuredTool.from_function( func=text2img, name="Image Generation", description="Siliconflow Image Generation is a tool that can generate images from text prompts using the Siliconflow API.", diff --git a/backend/app/core/tools/spark/__init__.py b/backend/app/core/tools/spark/__init__.py new file mode 100644 index 00000000..f455622d --- /dev/null +++ b/backend/app/core/tools/spark/__init__.py @@ -0,0 +1,3 @@ +from .spark_img_gen import spark_img_generation + +__all__ = ["spark_img_generation"] diff --git a/backend/app/core/tools/spark/spark.py b/backend/app/core/tools/spark/spark_img_gen.py similarity index 98% rename from backend/app/core/tools/spark/spark.py rename to backend/app/core/tools/spark/spark_img_gen.py index d8572bb4..0c567e0c 100644 --- a/backend/app/core/tools/spark/spark.py +++ b/backend/app/core/tools/spark/spark_img_gen.py @@ -133,7 +133,7 @@ def img_generation(prompt): return json.dumps(f"There is a error occured . 
{e}") -spark = StructuredTool.from_function( +spark_img_generation = StructuredTool.from_function( func=img_generation, name="Spark Image Generation", description="Spark Image Generation is a tool that can generate images from text prompts using the Spark API.", diff --git a/backend/app/core/tools/tool.py b/backend/app/core/tools/tool.py new file mode 100644 index 00000000..8dd9bcda --- /dev/null +++ b/backend/app/core/tools/tool.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Any, Optional, Union, cast + +from pydantic import BaseModel, Field, field_validator + + +class ToolInvokeMessage(BaseModel): + class MessageType(Enum): + TEXT = "text" + IMAGE = "image" + LINK = "link" + BLOB = "blob" + JSON = "json" + IMAGE_LINK = "image_link" + + type: MessageType = MessageType.TEXT + """ + plain text, image url or link url + """ + message: Union[str, bytes, dict] = None + meta: dict[str, Any] = None + save_as: str = "" + + +def create_image_message(image: str, save_as: str = "") -> ToolInvokeMessage: + """ + create an image message + + :param image: the url of the image + :return: the image message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as + ) + + +def create_link_message(link: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a link message + + :param link: the url of the link + :return: the link message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as + ) + + +def create_text_message(text: str, save_as: str = "") -> ToolInvokeMessage: + """ + create a text message + + :param text: the text + :return: the text message + """ + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as + ) + + +def create_blob_message( + blob: bytes, meta: dict = None, save_as: str = "" +) -> ToolInvokeMessage: + """ + create a blob message + + :param blob: the blob + :return: the blob message + """ + 
return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=blob, + meta=meta, + save_as=save_as, + ) + + +def create_json_message(object: dict) -> ToolInvokeMessage: + """ + create a json message + """ + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object) diff --git a/backend/app/core/tools/tool_manager.py b/backend/app/core/tools/tool_manager.py index d09a8717..393f1144 100644 --- a/backend/app/core/tools/tool_manager.py +++ b/backend/app/core/tools/tool_manager.py @@ -30,20 +30,27 @@ def load_tools(self): ): try: module = importlib.import_module( - f".{item}.{item}", package="app.core.tools" + f".{item}", package="app.core.tools" ) - # Try to get the tool instance directly - tool_instance = getattr(module, item) - if isinstance(tool_instance, BaseTool): - formatted_name = self.format_tool_name(item) - self.managed_tools[formatted_name] = ToolInfo( - description=tool_instance.description, - tool=tool_instance, - display_name=tool_instance.name, - ) + # Check if __all__ is defined in the module + if hasattr(module, "__all__"): + for tool_name in module.__all__: + tool_instance = getattr(module, tool_name, None) + if isinstance(tool_instance, BaseTool): + formatted_name = self.format_tool_name(item) + self.managed_tools[formatted_name] = ToolInfo( + description=tool_instance.description, + tool=tool_instance, + display_name=tool_instance.name, + ) + else: + print( + f"Warning: {tool_name} in {item} is not an instance of BaseTool" + ) else: - print(f"Warning: {item} is not an instance of BaseTool") + print(f"Warning: {item} does not define __all__") + except (ImportError, AttributeError) as e: print(f"Failed to load tool {item}: {e}") diff --git a/backend/app/core/tools/zhipuai/__init__.py b/backend/app/core/tools/zhipuai/__init__.py new file mode 100644 index 00000000..4446dc98 --- /dev/null +++ b/backend/app/core/tools/zhipuai/__init__.py @@ -0,0 +1,3 @@ +from .img_4v import img_understanding + +__all__ = 
["img_understanding"] diff --git a/backend/app/core/tools/zhipuai/img_4v.py b/backend/app/core/tools/zhipuai/img_4v.py new file mode 100644 index 00000000..5109dcac --- /dev/null +++ b/backend/app/core/tools/zhipuai/img_4v.py @@ -0,0 +1,45 @@ +import base64 +from zhipuai import ZhipuAI +import os +from langchain.pydantic_v1 import BaseModel, Field + +from langchain.tools import StructuredTool + + +class ImageUnderstandingInput(BaseModel): + """Input for the Image Understanding tool.""" + + text: str = Field(description="the input text for the Image Understanding tool") + + +def img_4v(text: str): + img_path = "/Users/envys/Downloads/a.jpeg" + with open(img_path, "rb") as img_file: + img_base = base64.b64encode(img_file.read()).decode("utf-8") + + client = ZhipuAI( + api_key=os.environ.get("ZHIPUAI_API_KEY"), + ) # 填写您自己的APIKey + response = client.chat.completions.create( + model="glm-4v", # 填写需要调用的模型名称 + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": img_base}}, + {"type": "text", "text": text}, + ], + } + ], + ) + + return response.choices[0].message + + +img_understanding = StructuredTool.from_function( + func=img_4v, + name="Image Understanding", + description="Users input an image and a question, and the LLM can identify objects, scenes, and other information in the image to answer the user's question.", + args_schema=ImageUnderstandingInput, + return_direct=True, +) diff --git a/backend/app/curd/models.py b/backend/app/curd/models.py index 2485aadd..3143d3cd 100644 --- a/backend/app/curd/models.py +++ b/backend/app/curd/models.py @@ -5,7 +5,7 @@ from ..models import ModelOut, ModelProviderOut, Models, ModelsBase, ModelsOut -def create_model(session: Session, model: ModelsBase) -> Models: +def _create_model(session: Session, model: ModelsBase) -> Models: try: db_model = Models(**model.model_dump()) session.add(db_model) @@ -92,7 +92,7 @@ def get_all_models(session: Session) -> ModelsOut: return 
ModelsOut(data=model_outs, count=total_count) -def delete_model(session: Session, model_id: int) -> Optional[Models]: +def _delete_model(session: Session, model_id: int) -> Optional[Models]: db_model = session.exec(select(Models).where(Models.id == model_id)).first() if db_model: session.delete(db_model) @@ -100,7 +100,7 @@ def delete_model(session: Session, model_id: int) -> Optional[Models]: return db_model -def update_model( +def _update_model( session: Session, model_id: int, model_update: ModelsBase ) -> Optional[Models]: db_model = session.exec(select(Models).where(Models.id == model_id)).first() diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a3d09520..45d2ca9b 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -14,7 +14,8 @@ passlib = {extras = ["bcrypt"], version = "^1.7.4"} tenacity = "^8.2.3" pydantic = ">2.0" emails = "^0.6" - +zhipuai = "2.1.5.20230904" +numexpr = "2.10.1" gunicorn = "^22.0.0" jinja2 = "^3.1.4" alembic = "^1.12.1" diff --git a/web/src/app/(applayout)/teams/[teamId]/page.tsx b/web/src/app/(applayout)/teams/[teamId]/page.tsx index 4683daa3..90e187fd 100644 --- a/web/src/app/(applayout)/teams/[teamId]/page.tsx +++ b/web/src/app/(applayout)/teams/[teamId]/page.tsx @@ -10,14 +10,10 @@ import { Spinner, } from "@chakra-ui/react"; import { useQuery } from "react-query"; - import Flow from "@/components/ReactFlow/Flow"; - -import ShowFlow from "@/app/flow/show/ShowFlow"; import DebugPreview from "@/components/Teams/DebugPreview"; import NormalTeamSettings from "@/components/Teams/NormalTeamSettings"; import WorkflowTeamSettings from "@/components/Teams/WorkflowTeamSettings"; -import { color } from "framer-motion"; import Link from "next/link"; import { useParams } from "next/navigation"; import { useRef } from "react"; @@ -32,7 +28,7 @@ function Team() { isError, error, } = useQuery(`team/${teamId}`, () => - TeamsService.readTeam({ id: Number.parseInt(teamId) }), + TeamsService.readTeam({ id: 
Number.parseInt(teamId) }) ); if (isError) { @@ -118,7 +114,7 @@ function Team() { ) : ( - + )} diff --git a/web/src/app/(applayout)/tools/page.tsx b/web/src/app/(applayout)/tools/page.tsx index 557adab6..13de8e30 100644 --- a/web/src/app/(applayout)/tools/page.tsx +++ b/web/src/app/(applayout)/tools/page.tsx @@ -2,7 +2,7 @@ import { type ApiError, ToolsService } from "@/client"; import ActionsMenu from "@/components/Common/ActionsMenu"; import { - Badge, + Text, Box, Flex, Heading, @@ -125,14 +125,19 @@ function Skills() { .replace(/ /g, "_")} /> - {skill.display_name} + + {skill.display_name} + - {skill.description} + + {skill.description} + {!skill.managed ? ( diff --git a/web/src/components/Icons/Tools/index.tsx b/web/src/components/Icons/Tools/index.tsx index bc93675e..b9a95b43 100644 --- a/web/src/components/Icons/Tools/index.tsx +++ b/web/src/components/Icons/Tools/index.tsx @@ -1,5 +1,5 @@ import { Icon, type IconProps, createIcon } from "@chakra-ui/icons"; -import { SiliconFlowIcon } from "../models"; +import { SiliconFlowIcon, ZhipuAIIcon } from "../models"; const OpenWeather = createIcon({ displayName: "OpenWeather", @@ -367,6 +367,7 @@ const iconMap: { [key: string]: React.FC } = { math_calculator: calculator, image_generation: SiliconFlowIcon, spark_image_generation: spark, + image_understanding: ZhipuAIIcon, }; const DefaultIcon = Wikipedia; diff --git a/web/src/components/Teams/DebugPreview/index.tsx b/web/src/components/Teams/DebugPreview/index.tsx index 7173bb28..d8d9888b 100644 --- a/web/src/components/Teams/DebugPreview/index.tsx +++ b/web/src/components/Teams/DebugPreview/index.tsx @@ -19,6 +19,7 @@ function DebugPreview({ borderRadius={"lg"} display={"flex"} flexDirection={"column"} + overflow={"hidden"} > - + {member?.map((member) => (