Skip to content

Commit

Permalink
feat:add img under standing tool and refactor code to add toolkits su…
Browse files Browse the repository at this point in the history
…pport (#57)

* fix:rename model curd function name to solve the confilct between curd and router

* feat:add img understanding tools

* feat:support toolkits

* fix:img understanding icon
  • Loading branch information
Onelevenvy authored Oct 1, 2024
1 parent 7a664e3 commit 4993b9b
Show file tree
Hide file tree
Showing 23 changed files with 208 additions and 40 deletions.
4 changes: 3 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ OPEN_WEATHER_API_KEY=changethis
# Spark get apikey from https://console.xfyun.cn/app/myapp
SPARK_APPID=changethis
SPARK_APISecret=changethis
SPARK_APIKey=changethis
SPARK_APIKey=changethis

ZHIPUAI_API_KEY=changethis
12 changes: 6 additions & 6 deletions backend/app/api/routes/providermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from app.api.deps import SessionDep
from app.curd.models import (
create_model,
delete_model,
_create_model,
_delete_model,
get_all_models,
get_models_by_provider,
update_model,
_update_model,
)
from app.models import Models, ModelsBase, ModelsOut

Expand All @@ -16,7 +16,7 @@
# Routes for Models
@router.post("/", response_model=ModelsBase)
def create_models(model: ModelsBase, session: SessionDep):
return create_model(session, model)
return _create_model(session, model)


@router.get("/{provider_id}", response_model=ModelsOut)
Expand All @@ -31,15 +31,15 @@ def read_models(session: SessionDep):

@router.delete("/{model_id}", response_model=Models)
def delete_model(model_id: int, session: SessionDep):
model = delete_model(session, model_id)
model = _delete_model(session, model_id)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
return model


@router.put("/{model_id}", response_model=Models)
def update_model(model_id: int, model_update: ModelsBase, session: SessionDep):
model = update_model(session, model_id, model_update)
model = _update_model(session, model_id, model_update)
if model is None:
raise HTTPException(status_code=404, detail="Model not found")
return model
12 changes: 9 additions & 3 deletions backend/app/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,15 @@ def init_modelprovider_model_db(session: Session) -> None:
(3, 'gpt4o-mini', 4),
(4, 'llama3.1:8b', 1),
(5, 'Qwen/Qwen2-7B-Instruct', 2),
(6, 'glm-4', 3),
(7, 'glm-4-0520', 3),
(8, 'glm-4-flash', 3)
(6, 'glm-4-alltools', 3),
(7, 'glm-4-flash', 3),
(8, 'glm-4-0520', 3),
(9, 'glm-4-plus', 3),
(10, 'glm-4v-plus', 3),
(11, 'glm-4', 3),
(12, 'glm-4v', 3)
ON CONFLICT (id) DO NOTHING;
"""

Expand Down
3 changes: 3 additions & 0 deletions backend/app/core/tools/ask_human/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ask_human import ask_human

__all__ = ["ask_human"]
3 changes: 3 additions & 0 deletions backend/app/core/tools/google/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .googletranslate import googletranslate

__all__ = ["googletranslate"]
3 changes: 3 additions & 0 deletions backend/app/core/tools/math/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .math import math

__all__ = ["math"]
3 changes: 3 additions & 0 deletions backend/app/core/tools/openweather/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .openweather import openweather

__all__ = ["openweather"]
3 changes: 3 additions & 0 deletions backend/app/core/tools/siliconflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .siliconflow_img_gen import siliconflow_img_generation

__all__ = ["siliconflow_img_generation"]
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def text2img(
return json.dumps(f"There is a error occured . {e}")


siliconflow = StructuredTool.from_function(
siliconflow_img_generation = StructuredTool.from_function(
func=text2img,
name="Image Generation",
description="Siliconflow Image Generation is a tool that can generate images from text prompts using the Siliconflow API.",
Expand Down
3 changes: 3 additions & 0 deletions backend/app/core/tools/spark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .spark_img_gen import spark_img_generation

__all__ = ["spark_img_generation"]
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def img_generation(prompt):
return json.dumps(f"There is a error occured . {e}")


spark = StructuredTool.from_function(
spark_img_generation = StructuredTool.from_function(
func=img_generation,
name="Spark Image Generation",
description="Spark Image Generation is a tool that can generate images from text prompts using the Spark API.",
Expand Down
82 changes: 82 additions & 0 deletions backend/app/core/tools/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from enum import Enum
from typing import Any, Optional, Union, cast

from pydantic import BaseModel, Field, field_validator


class ToolInvokeMessage(BaseModel):
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
LINK = "link"
BLOB = "blob"
JSON = "json"
IMAGE_LINK = "image_link"

type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: Union[str, bytes, dict] = None
meta: dict[str, Any] = None
save_as: str = ""


def create_image_message(image: str, save_as: str = "") -> ToolInvokeMessage:
"""
create an image message
:param image: the url of the image
:return: the image message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as
)


def create_link_message(link: str, save_as: str = "") -> ToolInvokeMessage:
"""
create a link message
:param link: the url of the link
:return: the link message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK, message=link, save_as=save_as
)


def create_text_message(text: str, save_as: str = "") -> ToolInvokeMessage:
"""
create a text message
:param text: the text
:return: the text message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as
)


def create_blob_message(
blob: bytes, meta: dict = None, save_as: str = ""
) -> ToolInvokeMessage:
"""
create a blob message
:param blob: the blob
:return: the blob message
"""
return ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=blob,
meta=meta,
save_as=save_as,
)


def create_json_message(object: dict) -> ToolInvokeMessage:
"""
create a json message
"""
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object)
29 changes: 18 additions & 11 deletions backend/app/core/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,27 @@ def load_tools(self):
):
try:
module = importlib.import_module(
f".{item}.{item}", package="app.core.tools"
f".{item}", package="app.core.tools"
)
# Try to get the tool instance directly
tool_instance = getattr(module, item)
if isinstance(tool_instance, BaseTool):
formatted_name = self.format_tool_name(item)
self.managed_tools[formatted_name] = ToolInfo(
description=tool_instance.description,
tool=tool_instance,
display_name=tool_instance.name,
)

# Check if __all__ is defined in the module
if hasattr(module, "__all__"):
for tool_name in module.__all__:
tool_instance = getattr(module, tool_name, None)
if isinstance(tool_instance, BaseTool):
formatted_name = self.format_tool_name(item)
self.managed_tools[formatted_name] = ToolInfo(
description=tool_instance.description,
tool=tool_instance,
display_name=tool_instance.name,
)
else:
print(
f"Warning: {tool_name} in {item} is not an instance of BaseTool"
)
else:
print(f"Warning: {item} is not an instance of BaseTool")
print(f"Warning: {item} does not define __all__")

except (ImportError, AttributeError) as e:
print(f"Failed to load tool {item}: {e}")

Expand Down
3 changes: 3 additions & 0 deletions backend/app/core/tools/zhipuai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .img_4v import img_understanding

__all__ = ["img_understanding"]
45 changes: 45 additions & 0 deletions backend/app/core/tools/zhipuai/img_4v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import base64
from zhipuai import ZhipuAI
import os
from langchain.pydantic_v1 import BaseModel, Field

from langchain.tools import StructuredTool


class ImageUnderstandingInput(BaseModel):
"""Input for the Image Understanding tool."""

text: str = Field(description="the input text for the Image Understanding tool")


def img_4v(text: str):
img_path = "/Users/envys/Downloads/a.jpeg"
with open(img_path, "rb") as img_file:
img_base = base64.b64encode(img_file.read()).decode("utf-8")

client = ZhipuAI(
api_key=os.environ.get("ZHIPUAI_API_KEY"),
) # 填写您自己的APIKey
response = client.chat.completions.create(
model="glm-4v", # 填写需要调用的模型名称
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": img_base}},
{"type": "text", "text": text},
],
}
],
)

return response.choices[0].message


img_understanding = StructuredTool.from_function(
func=img_4v,
name="Image Understanding",
description="Users input an image and a question, and the LLM can identify objects, scenes, and other information in the image to answer the user's question.",
args_schema=ImageUnderstandingInput,
return_direct=True,
)
6 changes: 3 additions & 3 deletions backend/app/curd/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..models import ModelOut, ModelProviderOut, Models, ModelsBase, ModelsOut


def create_model(session: Session, model: ModelsBase) -> Models:
def _create_model(session: Session, model: ModelsBase) -> Models:
try:
db_model = Models(**model.model_dump())
session.add(db_model)
Expand Down Expand Up @@ -92,15 +92,15 @@ def get_all_models(session: Session) -> ModelsOut:
return ModelsOut(data=model_outs, count=total_count)


def delete_model(session: Session, model_id: int) -> Optional[Models]:
def _delete_model(session: Session, model_id: int) -> Optional[Models]:
db_model = session.exec(select(Models).where(Models.id == model_id)).first()
if db_model:
session.delete(db_model)
session.commit()
return db_model


def update_model(
def _update_model(
session: Session, model_id: int, model_update: ModelsBase
) -> Optional[Models]:
db_model = session.exec(select(Models).where(Models.id == model_id)).first()
Expand Down
3 changes: 2 additions & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ passlib = {extras = ["bcrypt"], version = "^1.7.4"}
tenacity = "^8.2.3"
pydantic = ">2.0"
emails = "^0.6"

zhipuai = "2.1.5.20230904"
numexpr = "2.10.1"
gunicorn = "^22.0.0"
jinja2 = "^3.1.4"
alembic = "^1.12.1"
Expand Down
8 changes: 2 additions & 6 deletions web/src/app/(applayout)/teams/[teamId]/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ import {
Spinner,
} from "@chakra-ui/react";
import { useQuery } from "react-query";

import Flow from "@/components/ReactFlow/Flow";

import ShowFlow from "@/app/flow/show/ShowFlow";
import DebugPreview from "@/components/Teams/DebugPreview";
import NormalTeamSettings from "@/components/Teams/NormalTeamSettings";
import WorkflowTeamSettings from "@/components/Teams/WorkflowTeamSettings";
import { color } from "framer-motion";
import Link from "next/link";
import { useParams } from "next/navigation";
import { useRef } from "react";
Expand All @@ -32,7 +28,7 @@ function Team() {
isError,
error,
} = useQuery(`team/${teamId}`, () =>
TeamsService.readTeam({ id: Number.parseInt(teamId) }),
TeamsService.readTeam({ id: Number.parseInt(teamId) })
);

if (isError) {
Expand Down Expand Up @@ -118,7 +114,7 @@ function Team() {
<WorkflowTeamSettings teamId={Number.parseInt(teamId)} />
</Box>
) : (
<Box h="full" maxH={"full"} borderRadius="md">
<Box h="full" maxH={"full"} borderRadius="md" minH={"full"}>
<NormalTeamSettings teamData={team} />
</Box>
)}
Expand Down
15 changes: 10 additions & 5 deletions web/src/app/(applayout)/tools/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { type ApiError, ToolsService } from "@/client";
import ActionsMenu from "@/components/Common/ActionsMenu";
import {
Badge,
Text,
Box,
Flex,
Heading,
Expand Down Expand Up @@ -125,14 +125,19 @@ function Skills() {
.replace(/ /g, "_")}
/>

<Heading size="md">{skill.display_name}</Heading>
<Heading noOfLines={1} size="md">
{skill.display_name}
</Heading>
</HStack>
<Box
overflow="hidden"
textOverflow="ellipsis"
// whiteSpace="nowrap"
minH={"55px"}
h={"55px"}
maxH={"55px"}
>
{skill.description}
<Text textOverflow="ellipsis" noOfLines={2}>
{skill.description}
</Text>
</Box>
<Box pt={4}>
{!skill.managed ? (
Expand Down
Loading

0 comments on commit 4993b9b

Please sign in to comment.