From 1af37d225ca2ce64e3d3dd17ad611f6c942f2d49 Mon Sep 17 00:00:00 2001
From: Yue Fei
Date: Wed, 21 Aug 2024 10:57:30 +0800
Subject: [PATCH] Fix chat with history (#159)

* Hotfix

* Fix lint

* Fix lint

* Fix chat_history

* Fix lint
---
 src/pai_rag/app/web/rag_client.py    | 24 ++++++++++++++++--------
 src/pai_rag/app/web/tabs/chat_tab.py | 15 +++------------
 src/pai_rag/core/rag_application.py  |  3 ++-
 3 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 63d39ee2..a4c582fd 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -31,10 +31,14 @@ class dotdict(dict):
 class RagWebClient:
     def __init__(self):
         self.endpoint = "http://127.0.0.1:8000/"  # default link
+        self.session_id = None
 
     def set_endpoint(self, endpoint: str):
         self.endpoint = endpoint if endpoint.endswith("/") else f"{endpoint}/"
 
+    def clear_history(self):
+        self.session_id = None
+
     @property
     def query_url(self):
         return f"{self.endpoint}service/query"
@@ -72,7 +76,7 @@ def get_evaluate_response_url(self):
         return f"{self.endpoint}service/evaluate/response"
 
     def _format_rag_response(
-        self, question, response, session_id: str = None, stream: bool = False
+        self, question, response, with_history: bool = False, stream: bool = False
     ):
         if stream:
             text = response["delta"]
@@ -80,6 +84,7 @@
             text = response["answer"]
 
         docs = response.get("docs", [])
+        session_id = response.get("session_id", None)
         is_finished = response.get("is_finished", True)
 
         referenced_docs = ""
@@ -89,6 +94,7 @@
             response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
             return response
         elif is_finished:
+            self.session_id = session_id
             for i, doc in enumerate(docs):
                 filename = doc["metadata"].get("file_name", None)
                 if filename:
@@ -101,7 +107,7 @@
                     images += f""""""
 
         formatted_answer = ""
-        if session_id:
+        if with_history and "new_query" in response:
             new_query = response["new_query"]
             formatted_answer += f"**Query Transformation**: {new_query} \n\n"
             formatted_answer += f"**Answer**: {text} \n\n"
@@ -113,7 +119,8 @@
         response["result"] = formatted_answer
         return response
 
-    def query(self, text: str, session_id: str = None, stream: bool = False):
+    def query(self, text: str, with_history: bool = False, stream: bool = False):
+        session_id = self.session_id if with_history else None
         q = dict(question=text, session_id=session_id, stream=stream)
         r = requests.post(self.query_url, json=q, stream=True)
         if r.status_code != HTTPStatus.OK:
@@ -121,7 +128,7 @@
         if not stream:
             response = dotdict(json.loads(r.text))
             yield self._format_rag_response(
-                text, response, session_id=session_id, stream=stream
+                text, response, with_history=with_history, stream=stream
             )
         else:
             full_content = ""
@@ -130,16 +137,17 @@
                 full_content += chunk_response.delta
                 chunk_response.delta = full_content
                 yield self._format_rag_response(
-                    text, chunk_response, session_id=session_id, stream=stream
+                    text, chunk_response, with_history=with_history, stream=stream
                 )
 
     def query_llm(
         self,
         text: str,
-        session_id: str = None,
+        with_history: bool = False,
         temperature: float = 0.1,
         stream: bool = False,
     ):
+        session_id = self.session_id if with_history else None
         q = dict(
             question=text,
             temperature=temperature,
@@ -155,7 +163,7 @@ def query_llm(
         if not stream:
             response = dotdict(json.loads(r.text))
             yield self._format_rag_response(
-                text, response, session_id=session_id, stream=stream
+                text, response, with_history=with_history, stream=stream
             )
         else:
             full_content = ""
@@ -164,7 +172,7 @@ def query_llm(
                 full_content += chunk_response.delta
                 chunk_response.delta = full_content
                 yield self._format_rag_response(
-                    text, chunk_response, session_id=session_id, stream=stream
+                    text, chunk_response, with_history=with_history, stream=stream
                 )
 
     def query_vector(self, text: str):
diff --git a/src/pai_rag/app/web/tabs/chat_tab.py b/src/pai_rag/app/web/tabs/chat_tab.py
index fd2926be..ad75e030 100644
--- a/src/pai_rag/app/web/tabs/chat_tab.py
+++ b/src/pai_rag/app/web/tabs/chat_tab.py
@@ -8,13 +8,10 @@
     ACCURATE_CONTENT_PROMPTS,
 )
 
-current_session_id = None
-
 
 def clear_history(chatbot):
+    rag_client.clear_history()
     chatbot = []
-    global current_session_id
-    current_session_id = None
     return chatbot, 0
 
 
@@ -23,7 +20,6 @@ def reset_textbox():
 
 
 def respond(input_elements: List[Any]):
-    global current_session_id
 
     update_dict = {}
     for element, value in input_elements.items():
@@ -44,19 +40,16 @@
     chatbot = update_dict["chatbot"]
     is_streaming = update_dict["is_streaming"]
 
-    if not update_dict["include_history"]:
-        current_session_id = None
-
     try:
         if query_type == "LLM":
             response_gen = rag_client.query_llm(
-                msg, session_id=current_session_id, stream=is_streaming
+                msg, with_history=update_dict["include_history"], stream=is_streaming
             )
         elif query_type == "Retrieval":
             response_gen = rag_client.query_vector(msg)
         else:
             response_gen = rag_client.query(
-                msg, session_id=current_session_id, stream=is_streaming
+                msg, with_history=update_dict["include_history"], stream=is_streaming
            )
 
     except RagApiError as api_error:
@@ -282,8 +275,6 @@ def change_prompt_template(prm_type):
     )
 
 def change_query_radio(query_type):
-    global current_session_id
-    current_session_id = None
     if query_type == "Retrieval":
         return {
             vs_col: gr.update(visible=True),
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index e3d0fe88..ee4324d8 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -188,7 +188,8 @@ async def aquery_llm(self, query: LlmQuery):
             return LlmResponse(answer=response.response, session_id=session_id)
         else:
             response = await llm_chat_engine.astream_chat(query.question)
-            return event_generator_async(response=response)
+            result_info = {"session_id": session_id}
+            return event_generator_async(response=response, extra_info=result_info)
 
     async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         """Query answer from RAG App via web search asynchronously.
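
Reviewer note (not part of the patch): the fix replaces the module-level `current_session_id` global in the Gradio tab with session state owned by `RagWebClient`. The client remembers the `session_id` the server returns with each finished response and sends it back only when the user opts into history; `clear_history` drops it so the next turn starts a fresh conversation. Below is a minimal, self-contained sketch of that round-trip. The `SessionTrackingClient` name is hypothetical; the `service/query` path and the `question`/`session_id`/`answer` fields mirror the patch, but the server behavior is assumed, not taken from the repo.

```python
import requests  # same HTTP client rag_client.py uses


class SessionTrackingClient:
    """Hedged sketch of the session pattern RagWebClient adopts in this patch."""

    def __init__(self, endpoint: str = "http://127.0.0.1:8000/"):
        self.endpoint = endpoint
        self.session_id = None  # no active conversation yet

    def clear_history(self):
        # Mirrors RagWebClient.clear_history: forget the conversation.
        self.session_id = None

    def query(self, question: str, with_history: bool = False) -> str:
        # Send the remembered id only when history is requested,
        # exactly as the patched query()/query_llm() do.
        session_id = self.session_id if with_history else None
        r = requests.post(
            f"{self.endpoint}service/query",
            json={"question": question, "session_id": session_id},
        )
        r.raise_for_status()
        body = r.json()
        # The server mints (or echoes) the id; remembering it here is what
        # lets the next with_history=True turn continue the conversation.
        self.session_id = body.get("session_id")
        return body["answer"]
```

The `rag_application.py` hunk is the streaming half of the same fix: a streamed answer has no single JSON body the client can read `session_id` from, so `aquery_llm` now hands `{"session_id": session_id}` to `event_generator_async` as `extra_info`, and the client commits `self.session_id` only once a chunk reports `is_finished`. A hypothetical sketch of how such a generator could attach those extra fields, assuming a llama-index style `async_response_gen()` on the streaming response (the real `event_generator_async` lives elsewhere in pai_rag and may differ):

```python
import json
from typing import Any, AsyncGenerator, Dict, Optional


async def event_generator_async(
    response: Any, extra_info: Optional[Dict[str, Any]] = None
) -> AsyncGenerator[str, None]:
    # Stream one JSON object per token; merge extra_info (e.g. the
    # session id) into every chunk so the client can pick it up.
    async for token in response.async_response_gen():
        chunk = {"delta": token, "is_finished": False, **(extra_info or {})}
        yield json.dumps(chunk, ensure_ascii=False) + "\n"
    # Final chunk flags completion, matching the client's is_finished check.
    yield json.dumps({"delta": "", "is_finished": True, **(extra_info or {})}) + "\n"
```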