Skip to content

Commit

Permalink
Fix chat with history (#159)
Browse files (browse the repository at this point in the history)
* Hotfix

* Fix lint

* Fix lint

* Fix chat_history

* Fix lint
  • Loading branch information
moria97 authored Aug 21, 2024
1 parent 1ad6dcf commit 1af37d2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
24 changes: 16 additions & 8 deletions src/pai_rag/app/web/rag_client.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,10 +31,14 @@ class dotdict(dict):
class RagWebClient:
def __init__(self):
    """Create a client that points at the local default service.

    No chat session exists yet, so ``session_id`` starts out as ``None``.
    """
    # Nothing to continue until a finished response supplies a session id.
    self.session_id = None
    self.endpoint = "http://127.0.0.1:8000/"  # default link

def set_endpoint(self, endpoint: str):
    """Store the service base URL, making sure it ends with a slash.

    The URL properties of this client append relative paths such as
    ``service/query`` directly to ``self.endpoint``, so the trailing ``/``
    has to be present.
    """
    if not endpoint.endswith("/"):
        endpoint = f"{endpoint}/"
    self.endpoint = endpoint

# Forget the stored session id; query()/query_llm() read self.session_id when
# with_history is set, so the next such request is sent with session_id=None.
def clear_history(self):
self.session_id = None

@property
def query_url(self):
    """URL of the ``service/query`` route under the configured endpoint."""
    base = self.endpoint
    return f"{base}service/query"
Expand Down Expand Up @@ -72,14 +76,15 @@ def get_evaluate_response_url(self):
return f"{self.endpoint}service/evaluate/response"

def _format_rag_response(
self, question, response, session_id: str = None, stream: bool = False
self, question, response, with_history: bool = False, stream: bool = False
):
if stream:
text = response["delta"]
else:
text = response["answer"]

docs = response.get("docs", [])
session_id = response.get("session_id", None)
is_finished = response.get("is_finished", True)

referenced_docs = ""
Expand All @@ -89,6 +94,7 @@ def _format_rag_response(
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
return response
elif is_finished:
self.session_id = session_id
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
if filename:
Expand All @@ -101,7 +107,7 @@ def _format_rag_response(
images += f"""<img src="{image_url}"/>"""

formatted_answer = ""
if session_id:
if with_history and "new_query" in response:
new_query = response["new_query"]
formatted_answer += f"**Query Transformation**: {new_query} \n\n"
formatted_answer += f"**Answer**: {text} \n\n"
Expand All @@ -113,15 +119,16 @@ def _format_rag_response(
response["result"] = formatted_answer
return response

def query(self, text: str, session_id: str = None, stream: bool = False):
def query(self, text: str, with_history: bool = False, stream: bool = False):
session_id = self.session_id if with_history else None
q = dict(question=text, session_id=session_id, stream=stream)
r = requests.post(self.query_url, json=q, stream=True)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)
if not stream:
response = dotdict(json.loads(r.text))
yield self._format_rag_response(
text, response, session_id=session_id, stream=stream
text, response, with_history=with_history, stream=stream
)
else:
full_content = ""
Expand All @@ -130,16 +137,17 @@ def query(self, text: str, session_id: str = None, stream: bool = False):
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
text, chunk_response, session_id=session_id, stream=stream
text, chunk_response, with_history=with_history, stream=stream
)

def query_llm(
self,
text: str,
session_id: str = None,
with_history: bool = False,
temperature: float = 0.1,
stream: bool = False,
):
session_id = self.session_id if with_history else None
q = dict(
question=text,
temperature=temperature,
Expand All @@ -155,7 +163,7 @@ def query_llm(
if not stream:
response = dotdict(json.loads(r.text))
yield self._format_rag_response(
text, response, session_id=session_id, stream=stream
text, response, with_history=with_history, stream=stream
)
else:
full_content = ""
Expand All @@ -164,7 +172,7 @@ def query_llm(
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
text, chunk_response, session_id=session_id, stream=stream
text, chunk_response, with_history=with_history, stream=stream
)

def query_vector(self, text: str):
Expand Down
15 changes: 3 additions & 12 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,13 +8,10 @@
ACCURATE_CONTENT_PROMPTS,
)

current_session_id = None


def clear_history(chatbot):
rag_client.clear_history()
chatbot = []
global current_session_id
current_session_id = None
return chatbot, 0


Expand All @@ -23,7 +20,6 @@ def reset_textbox():


def respond(input_elements: List[Any]):
global current_session_id
update_dict = {}

for element, value in input_elements.items():
Expand All @@ -44,19 +40,16 @@ def respond(input_elements: List[Any]):
chatbot = update_dict["chatbot"]
is_streaming = update_dict["is_streaming"]

if not update_dict["include_history"]:
current_session_id = None

try:
if query_type == "LLM":
response_gen = rag_client.query_llm(
msg, session_id=current_session_id, stream=is_streaming
msg, with_history=update_dict["include_history"], stream=is_streaming
)
elif query_type == "Retrieval":
response_gen = rag_client.query_vector(msg)
else:
response_gen = rag_client.query(
msg, session_id=current_session_id, stream=is_streaming
msg, with_history=update_dict["include_history"], stream=is_streaming
)

except RagApiError as api_error:
Expand Down Expand Up @@ -282,8 +275,6 @@ def change_prompt_template(prm_type):
)

def change_query_radio(query_type):
global current_session_id
current_session_id = None
if query_type == "Retrieval":
return {
vs_col: gr.update(visible=True),
Expand Down
3 changes: 2 additions & 1 deletion src/pai_rag/core/rag_application.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -188,7 +188,8 @@ async def aquery_llm(self, query: LlmQuery):
return LlmResponse(answer=response.response, session_id=session_id)
else:
response = await llm_chat_engine.astream_chat(query.question)
return event_generator_async(response=response)
result_info = {"session_id": session_id}
return event_generator_async(response=response, extra_info=result_info)

async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
"""Query answer from RAG App via web search asynchronously.
Expand Down

0 comments on commit 1af37d2

Please sign in to comment.