Skip to content

Commit

Permalink
Fix empty response for score_threshold (#136)
Browse files Browse the repository at this point in the history
* Fix empty response for score_threshold

* Modify empty response info

* Modify empty response info

---------

Co-authored-by: Yue Fei <[email protected]>
  • Loading branch information
wwxxzz and moria97 authored Jul 31, 2024
1 parent cd4c0b8 commit 179d6b2
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 10 deletions.
15 changes: 7 additions & 8 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_evaluate_response_url(self):
return f"{self.endpoint}service/evaluate/response"

def _format_rag_response(
self, response, session_id: str = None, stream: bool = False
self, question, response, session_id: str = None, stream: bool = False
):
if stream:
text = response["delta"]
Expand All @@ -85,9 +85,8 @@ def _format_rag_response(
referenced_docs = ""
images = ""

# 空结果,TODO: 适配score_threshold的场景
if is_finished and len(docs) == 0 and not text:
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
return response
elif is_finished:
for i, doc in enumerate(docs):
Expand Down Expand Up @@ -122,7 +121,7 @@ def query(self, text: str, session_id: str = None, stream: bool = False):
if not stream:
response = dotdict(json.loads(r.text))
yield self._format_rag_response(
response, session_id=session_id, stream=stream
text, response, session_id=session_id, stream=stream
)
else:
full_content = ""
Expand All @@ -131,7 +130,7 @@ def query(self, text: str, session_id: str = None, stream: bool = False):
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
chunk_response, session_id=session_id, stream=stream
text, chunk_response, session_id=session_id, stream=stream
)

def query_llm(
Expand All @@ -156,7 +155,7 @@ def query_llm(
if not stream:
response = dotdict(json.loads(r.text))
yield self._format_rag_response(
response, session_id=session_id, stream=stream
text, response, session_id=session_id, stream=stream
)
else:
full_content = ""
Expand All @@ -165,7 +164,7 @@ def query_llm(
full_content += chunk_response.delta
chunk_response.delta = full_content
yield self._format_rag_response(
chunk_response, session_id=session_id, stream=stream
text, chunk_response, session_id=session_id, stream=stream
)

def query_vector(self, text: str):
Expand All @@ -179,7 +178,7 @@ def query_vector(self, text: str):
"<tr><th>Document</th><th>Score</th><th>Text</th><th>Media</tr>\n"
)
if len(response["docs"]) == 0:
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=text)
else:
for i, doc in enumerate(response["docs"]):
html_content = markdown.markdown(doc["text"])
Expand Down
2 changes: 1 addition & 1 deletion src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@
],
}

EMPTY_KNOWLEDGEBASE_MESSAGE = "The knowledge base is empty. Kindly upload your knowledge files before executing a query."
EMPTY_KNOWLEDGEBASE_MESSAGE = "We couldn't find any documents related to your question: {query_str}. \n\n You may try lowering the similarity_threshold or uploading relevant knowledge files."
2 changes: 1 addition & 1 deletion src/pai_rag/modules/synthesizer/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
from pai_rag.utils.prompt_template import DEFAULT_TEXT_QA_PROMPT_TMPL
from pai_rag.integrations.synthesizer.my_synthesizer import MySimpleSummarize
from pai_rag.integrations.synthesizer.my_simple_synthesizer import MySimpleSummarize

logger = logging.getLogger(__name__)

Expand Down

0 comments on commit 179d6b2

Please sign in to comment.