Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support langchain new version and set agent False as default #90

Merged
merged 1 commit into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid

import uvicorn
from functools import partial
from fastapi import FastAPI, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import PlainTextResponse
Expand All @@ -17,6 +18,9 @@
parser.add_argument('--langchain', action='store_true')
parser.add_argument('--towhee', action='store_true')
parser.add_argument('--moniter', action='store_true')
parser.add_argument('--agent', action='store_true',
help='The default is False, which only works when `--langchain` is enabled.'
' It means using the agent in langchain to dynamically select tools.')
parser.add_argument('--max_observation', default=1000)
parser.add_argument('--name', default=str(uuid.uuid4()))
args = parser.parse_args()
Expand All @@ -29,13 +33,15 @@
USE_TOWHEE = args.towhee
MAX_OBSERVATION = args.max_observation
ENABLE_MONITER = args.moniter
ENABLE_AGENT = args.agent
NAME = args.name

assert (USE_LANGCHAIN and not USE_TOWHEE ) or (USE_TOWHEE and not USE_LANGCHAIN), \
'The service should start with either "--langchain" or "--towhee".'

if USE_LANGCHAIN:
from src_langchain.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413
chat = partial(chat, enable_agent=ENABLE_AGENT)
if USE_TOWHEE:
from src_towhee.operations import chat, insert, drop, check, get_history, clear_history, count # pylint: disable=C0413
if ENABLE_MONITER:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
langchain==0.0.230
langchain==0.0.322
unstructured
pexpect
pdf2image
Expand Down
6 changes: 3 additions & 3 deletions src_langchain/agent/chat_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, List, Optional, Sequence
from pydantic import Field

from langchain.agents.conversational_chat.prompt import PREFIX, SUFFIX
from pydantic import Field
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.agents import ConversationalChatAgent, AgentOutputParser, Agent
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseOutputParser, BaseLanguageModel
from langchain.schema import BaseOutputParser
from langchain.tools.base import BaseTool

from .output_parser import OutputParser
Expand Down
2 changes: 1 addition & 1 deletion src_langchain/data_loader/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __call__(self, data_src, source_type: str = 'file') -> List[str]:
token_count += len(self.enc.encode(doc))
return docs, token_count

def from_files(self, files: list, encoding: Optional[str] = None) -> List[Document]:
def from_files(self, files: list, encoding: Optional[str] = 'utf-8') -> List[Document]:
'''Load documents from path or file-like object, return a list of unsplit LangChain Documents'''
docs = []
for file in files:
Expand Down
54 changes: 33 additions & 21 deletions src_langchain/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

from langchain.agents import Tool, AgentExecutor
from langchain.chains import ConversationalRetrievalChain

sys.path.append(os.path.dirname(__file__))

Expand All @@ -13,41 +14,52 @@
from store import MemoryStore, DocStore # pylint: disable=C0413
from data_loader import DataParser # pylint: disable=C0413


logger = logging.getLogger(__name__)

encoder = TextEncoder()
chat_llm = ChatLLM()
load_data = DataParser()


def chat(session_id, project, question):
def chat(session_id, project, question, enable_agent=False):
'''Chat API'''
doc_db = DocStore(
table_name=project,
embedding_func=encoder,
)
memory_db = MemoryStore(table_name=project, session_id=session_id)

tools = [
Tool(
name='Search',
func=doc_db.search,
description='Search through Milvus.'
if enable_agent: # use agent
memory_db.memory.output_key = None
tools = [
Tool(
name='Search',
func=doc_db.search,
description='useful for search professional knowledge and information'
)
]
agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools)
agent_chain = AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
memory=memory_db.memory,
verbose=False
)
]
agent = ChatAgent.from_llm_and_tools(llm=chat_llm, tools=tools)
agent_chain = AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
memory=memory_db.memory,
verbose=False
)
try:
final_answer = agent_chain.run(input=question)
return question, final_answer
except Exception as e: # pylint: disable=W0703
return question, f'Something went wrong:\n{e}'
try:
final_answer = agent_chain.run(input=question)
return question, final_answer
except Exception as e: # pylint: disable=W0703
return question, f'Something went wrong:\n{e}'
else: # use chain
memory_db.memory.output_key = 'answer'
qa = ConversationalRetrievalChain.from_llm(
llm=chat_llm,
retriever=doc_db.vector_db.as_retriever(),
memory=memory_db.memory,
return_generated_question=True
)
qa_result = qa(question)
return qa_result['generated_question'], qa_result['answer']


def insert(data_src, project, source_type: str = 'file'):
Expand Down Expand Up @@ -93,6 +105,7 @@ def check(project):
raise RuntimeError from e
return {'store': doc_check, 'memory': memory_check}


def count(project):
'''Count entities.'''
try:
Expand Down Expand Up @@ -130,7 +143,6 @@ def load(document_strs: List[str], project: str):
num = doc_db.insert(document_strs)
return num


# if __name__ == '__main__':
# project = 'akcio'
# data_src = 'https://docs.towhee.io/'
Expand Down
Loading