Skip to content

Commit

Permalink
support langchain new version and set agent False as default
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <[email protected]>
  • Loading branch information
zc277584121 committed Oct 26, 2023
1 parent 35d18cd commit 07f0a15
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 25 deletions.
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

import uvicorn
from functools import partial
from fastapi import FastAPI, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse
Expand All @@ -17,6 +18,9 @@
parser.add_argument('--langchain', action='store_true')
parser.add_argument('--towhee', action='store_true')
parser.add_argument('--moniter', action='store_true')
parser.add_argument('--agent', action='store_true',
help='The default is False, which only works when `--langchain` is enabled.'
' It means using the agent in langchain to dynamically select tools.')
parser.add_argument('--max_observation', default=1000)
parser.add_argument('--name', default=str(uuid.uuid4()))
args = parser.parse_args()
Expand All @@ -29,13 +33,15 @@
USE_TOWHEE = args.towhee
MAX_OBSERVATION = args.max_observation
ENABLE_MONITER = args.moniter
ENABLE_AGENT = args.agent
NAME = args.name

assert (USE_LANGCHAIN and not USE_TOWHEE ) or (USE_TOWHEE and not USE_LANGCHAIN), \
'The service should start with either "--langchain" or "--towhee".'

if USE_LANGCHAIN:
from src_langchain.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413
chat = partial(chat, enable_agent=ENABLE_AGENT)
if USE_TOWHEE:
from src_towhee.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413
if ENABLE_MONITER:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
langchain==0.0.230
langchain
unstructured
pexpect
pdf2image
Expand Down
6 changes: 3 additions & 3 deletions src_langchain/agent/chat_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, List, Optional, Sequence
from pydantic import Field

from langchain.agents.conversational_chat.prompt import PREFIX, SUFFIX
from pydantic import Field
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.agents import ConversationalChatAgent, AgentOutputParser, Agent
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseOutputParser, BaseLanguageModel
from langchain.schema import BaseOutputParser
from langchain.tools.base import BaseTool

from .output_parser import OutputParser
Expand Down
2 changes: 1 addition & 1 deletion src_langchain/data_loader/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __call__(self, data_src, source_type: str = 'file') -> List[str]:
token_count += len(self.enc.encode(doc))
return docs, token_count

def from_files(self, files: list, encoding: Optional[str] = None) -> List[Document]:
def from_files(self, files: list, encoding: Optional[str] = 'utf-8') -> List[Document]:
'''Load documents from path or file-like object, return a list of unsplit LangChain Documents'''
docs = []
for file in files:
Expand Down
49 changes: 29 additions & 20 deletions src_langchain/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from langchain.agents import Tool, AgentExecutor
from langchain.chains import ConversationalRetrievalChain

sys.path.append(os.path.dirname(__file__))

Expand All @@ -13,41 +14,49 @@
from store import MemoryStore, DocStore # pylint: disable=C0413
from data_loader import DataParser # pylint: disable=C0413


logger = logging.getLogger(__name__)

encoder = TextEncoder()
chat_llm = ChatLLM()
load_data = DataParser()


def chat(session_id, project, question):
def chat(session_id, project, question, enable_agent=False):
'''Chat API'''
doc_db = DocStore(
table_name=project,
embedding_func=encoder,
)
memory_db = MemoryStore(table_name=project, session_id=session_id)

tools = [
Tool(
name='Search',
func=doc_db.search,
description='Search through Milvus.'
if enable_agent: # use agent
tools = [
Tool(
name='Search',
func=doc_db.search,
description='Search through Milvus.'
)
]
agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools)
agent_chain = AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
memory=memory_db.memory,
verbose=False
)
]
agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools)
agent_chain = AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
memory=memory_db.memory,
verbose=False
)
try:
final_answer = agent_chain.run(input=question)
try:
final_answer = agent_chain.run(input=question)
return question, final_answer
except Exception as e: # pylint: disable=W0703
return question, f'Something went wrong:\n{e}'
else: # use chain
qa = ConversationalRetrievalChain.from_llm(
llm=chat_llm,
retriever=doc_db.vector_db.as_retriever(),
memory=memory_db.memory,
)
final_answer = qa(question)['answer']
return question, final_answer
except Exception as e: # pylint: disable=W0703
return question, f'Something went wrong:\n{e}'


def insert(data_src, project, source_type: str = 'file'):
Expand Down Expand Up @@ -93,6 +102,7 @@ def check(project):
raise RuntimeError from e
return {'store': doc_check, 'memory': memory_check}


def count(project):
'''Count entities.'''
try:
Expand Down Expand Up @@ -130,7 +140,6 @@ def load(document_strs: List[str], project: str):
num = doc_db.insert(document_strs)
return num


# if __name__ == '__main__':
# project = 'akcio'
# data_src = 'https://docs.towhee.io/'
Expand Down

0 comments on commit 07f0a15

Please sign in to comment.