Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into eugene/tlk-1990-create-documentat…
Browse files Browse the repository at this point in the history
…ion-for-debugger-settings-in-popular-ides
  • Loading branch information
EugeneLightsOn committed Nov 22, 2024
2 parents 48e2fb1 + 5502d8d commit 2a7ef6b
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 166 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ down:

.PHONY: run-unit-tests
run-unit-tests:
poetry run pytest src/backend/tests/unit --cov=src/backend --cov-report=xml
poetry run pytest src/backend/tests/unit/$(file) --cov=src/backend --cov-report=xml

.PHONY: run-community-tests
run-community-tests:
poetry run pytest src/community/tests --cov=src/community --cov-report=xml
poetry run pytest src/community/tests/$(file) --cov=src/community --cov-report=xml

.PHONY: run-integration-tests
run-integration-tests:
Expand Down
36 changes: 22 additions & 14 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
from backend.tools.base import ToolAuthException, ToolError, ToolErrorCode

TIMEOUT_SECONDS = 60

Expand Down Expand Up @@ -76,8 +77,8 @@ async def _call_tool_async(
tool_call: dict,
deployment_model: BaseDeployment,
) -> List[Dict[str, Any]]:
tool = get_available_tools().get(tool_call["name"])
if not tool:
tool_definition = get_available_tools().get(tool_call["name"])
if not tool_definition:
logger.info(
event=f"[Custom Chat] Tool not included in tools parameter: {tool_call['name']}",
)
Expand All @@ -89,8 +90,10 @@ async def _call_tool_async(
]
return outputs

tool = tool_definition.implementation()

try:
outputs = await tool.implementation().call(
outputs = await tool.call(
parameters=tool_call.get("parameters"),
ctx=ctx,
session=db,
Expand All @@ -101,23 +104,28 @@ async def _call_tool_async(
conversation_id=ctx.get_conversation_id(),
agent_tool_metadata=ctx.get_agent_tool_metadata(),
)
except ToolAuthException as e:
return [
{
"call": tool_call,
"outputs": tool.get_tool_error(
ToolError(
text="Tool authentication failed",
details=str(e),
type=ToolErrorCode.AUTH,
)
),
}
]
except Exception as e:
logger.exception(
event=f"[Custom Chat] Error while calling tool {tool_call['name']}: {str(e)}",
error=str(e),
)
outputs = [
return [
{
"call": tool_call,
"outputs": [{"error": str(e), "status_code": 500, "success": False}],
"outputs": tool.get_tool_error(ToolError(text=str(e))),
}
]
return outputs

# If the tool returns a list of outputs, append each output to the tool_results list
# Otherwise, append the single output to the tool_results list
outputs = outputs if isinstance(outputs, list) else [outputs]
tool_results = []
for output in outputs:
tool_results.append({"call": tool_call, "outputs": [output]})
return tool_results
return [{"call": tool_call, "outputs": outputs}]
9 changes: 3 additions & 6 deletions src/backend/tests/unit/chat/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ async def call(
"name": "toolkit_calculator",
"parameters": {"code": "6*7"},
},
"outputs": [
{"error": "Calculator failed", "status_code": 500, "success": False}
],
"outputs": [{'type': 'other', 'success': False, 'text': 'Calculator failed', 'details': ''}],
},
]

Expand Down Expand Up @@ -148,11 +146,10 @@ async def call(
results = asyncio.run(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)

assert {
"call": {"name": "web_scrape", "parameters": {"code": "6*7"}},
"outputs": [
{"error": "Web scrape failed", "status_code": 500, "success": False}
],
"outputs": [{"type": "other", 'success': False, 'text': 'Web scrape failed', 'details': ''}],
} in results
assert {
"call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}},
Expand Down
32 changes: 27 additions & 5 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
from abc import ABC, abstractmethod
from enum import StrEnum
from typing import Any, Dict, List

from fastapi import Request
from pydantic import BaseModel

from backend.config.settings import Settings
from backend.crud import tool_auth as tool_auth_crud
Expand All @@ -13,6 +15,21 @@

logger = LoggerFactory().get_logger()

class ToolErrorCode(StrEnum):
HTTP_ERROR = "http_error"
AUTH = "auth"
OTHER = "other"

class ToolAuthException(Exception):
def __init__(self, message, tool_id: str):
self.message = message
self.tool_id = tool_id

class ToolError(BaseModel, extra="allow"):
type: ToolErrorCode = ToolErrorCode.OTHER
success: bool = False
text: str
details: str = ""

class BaseTool():
"""
Expand Down Expand Up @@ -48,6 +65,16 @@ def generate_error_message(cls) -> str | None:
@classmethod
def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None: ...

@classmethod
def get_tool_error(cls, err: ToolError):
tool_error = err.model_dump()
logger.error(event=f"Error calling tool {cls.ID}", error=tool_error)
return [tool_error]

@classmethod
def get_no_results_error(cls):
return cls.get_tool_error(ToolError(text="No results found."))

@abstractmethod
async def call(
self, parameters: dict, ctx: Any, **kwargs: Any
Expand Down Expand Up @@ -135,8 +162,3 @@ def delete_tool_auth(self, session: DBSessionDep, user_id: str) -> bool:
)
raise


class ToolAuthException(Exception):
def __init__(self, message, tool_id: str):
self.message = message
self.tool_id = tool_id
76 changes: 13 additions & 63 deletions src/backend/tools/brave_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from backend.config.settings import Settings
from backend.database_models.database import DBSessionDep
from backend.model_deployments.base import BaseDeployment
from backend.schemas.agent import AgentToolMetadataArtifactsType
from backend.schemas.tool import ToolCategory, ToolDefinition
from backend.tools.base import BaseTool
Expand Down Expand Up @@ -58,71 +57,22 @@ async def call(
AgentToolMetadataArtifactsType.DOMAIN, session, ctx
)

result = await self.client.search_async(
response = await self.client.search_async(
q=query, count=self.num_results, include_domains=filtered_domains
)
result = dict(result)
response = dict(response)

if "web" not in result and "results" not in result["web"]:
return []
results = response.get("web", {}).get("results", [])

transformed_results = []
for item in result["web"]["results"]:
new_result = {
"url": item["url"],
"title": item["title"],
"content": item["description"],
}
transformed_results.append(new_result)
if not results:
self.get_no_results_error()

reranked_results = await self.rerank_page_snippets(
query,
transformed_results,
model=kwargs.get("model_deployment"),
ctx=ctx,
**kwargs,
)

return [
{"url": result["url"], "text": result["content"], "title": result["title"]}
for result in reranked_results
]

async def rerank_page_snippets(
self,
query: str,
snippets: List[Dict[str, Any]],
model: BaseDeployment,
ctx: Any,
**kwargs: Any,
) -> List[Dict[str, Any]]:
if len(snippets) == 0:
return []

rerank_batch_size = 500
relevance_scores = [None for _ in range(len(snippets))]
for batch_start in range(0, len(snippets), rerank_batch_size):
snippet_batch = snippets[batch_start : batch_start + rerank_batch_size]
batch_output = await model.invoke_rerank(
query=query,
documents=[
f"{snippet['title']} {snippet['content']}"
for snippet in snippet_batch
],
ctx=ctx,
)
for b in batch_output.get("results", []):
index = b.get("index", None)
relevance_score = b.get("relevance_score", None)
if index is not None:
relevance_scores[batch_start + index] = relevance_score

reranked, seen_urls = [], []
for _, result in sorted(
zip(relevance_scores, snippets), key=lambda x: x[0], reverse=True
):
if result["url"] not in seen_urls:
seen_urls.append(result["url"])
reranked.append(result)
tool_results = []
for result in results:
tool_results.append({
"text": result.get("description"),
"title": result.get("title"),
"url": result.get("url"),
})

return reranked[: self.num_results]
return tool_results
18 changes: 11 additions & 7 deletions src/backend/tools/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,18 @@ async def call(
response = cse.list(q=query, cx=self.CSE_ID, orTerms=site_filters).execute()
search_results = response.get("items", [])

if not search_results:
return self.get_no_results_error()

tool_results = []
for result in search_results:
tool_result = {
"title": result["title"],
"url": result["link"],
}
if "snippet" in result:
tool_result["text"] = result["snippet"]
tool_results.append(tool_result)
if "snippet" not in result:
continue

tool_results.append({
"text": result.get("snippet"),
"title": result.get("title"),
"url": result.get("url")
})

return tool_results
9 changes: 5 additions & 4 deletions src/backend/tools/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class HybridWebSearch(BaseTool, WebSearchFilteringMixin):
ID = "hybrid_web_search"
POST_RERANK_MAX_RESULTS = 6
POST_RERANK_MAX_RESULTS = 5
AVAILABLE_WEB_SEARCH_TOOLS = [TavilyWebSearch, GoogleWebSearch, BraveWebSearch]
ENABLED_WEB_SEARCH_TOOLS = Settings().get('tools.hybrid_web_search.enabled_web_searches')
WEB_SCRAPE_TOOL = WebScrapeTool
Expand Down Expand Up @@ -159,8 +159,9 @@ async def rerank_results(
for _, result in sorted(
zip(relevance_scores, results), key=lambda x: x[0], reverse=True
):
if result["url"] not in seen_urls:
seen_urls.append(result["url"])
url = result.get("url")
if url not in seen_urls:
seen_urls.append(url)
reranked.append(result)

return reranked[: self.POST_RERANK_MAX_RESULTS]
return reranked[:self.POST_RERANK_MAX_RESULTS]
Loading

0 comments on commit 2a7ef6b

Please sign in to comment.