Skip to content

Commit

Permalink
Add web search (#161)
Browse files Browse the repository at this point in the history
* Add web search

* Fix lint

* Fix bug

* Update timeout

* Fix bug
  • Loading branch information
moria97 authored Aug 22, 2024
1 parent 95e4bf3 commit 49961e4
Show file tree
Hide file tree
Showing 16 changed files with 912 additions and 18 deletions.
496 changes: 494 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pgvector = "^0.3.2"
pre-commit = "^3.8.0"
cn-clip = "^1.5.1"
llama-index-llms-paieas = "^0.1.0"
llama-index-readers-web = "^0.1.23"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
1 change: 1 addition & 0 deletions pyproject_gpu.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ asyncpg = "^0.29.0"
pgvector = "^0.3.2"
pre-commit = "^3.8.0"
llama-index-llms-paieas = "^0.1.0"
llama-index-readers-web = "^0.1.23"

[tool.poetry.scripts]
pai_rag = "pai_rag.main:main"
Expand Down
12 changes: 12 additions & 0 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ async def aquery_llm(query: RagQuery):
)


@router.post("/query/search")
async def aquery_search(query: RagQuery):
    """Answer a query via the web-search pipeline.

    Returns the full response for non-streaming requests, or an SSE
    stream (text/event-stream) when query.stream is set.
    """
    response = await rag_service.aquery_search(query)
    if query.stream:
        return StreamingResponse(
            response,
            media_type="text/event-stream",
        )
    return response


@router.post("/query/retrieval")
async def aquery_retrieval(query: RetrievalQuery):
    """Retrieve matching documents for the query (no answer synthesis)."""
    result = await rag_service.aquery_retrieval(query)
    return result
Expand Down
43 changes: 39 additions & 4 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def set_endpoint(self, endpoint: str):
def query_url(self):
return f"{self.endpoint}service/query"

@property
def search_url(self):
return f"{self.endpoint}service/query/search"

@property
def llm_url(self):
return f"{self.endpoint}service/query/llm"
Expand Down Expand Up @@ -91,14 +95,20 @@ def _format_rag_response(
response["result"] = EMPTY_KNOWLEDGEBASE_MESSAGE.format(query_str=question)
return response
elif is_finished:
seen_filenames = set()
file_idx = 1
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
if filename:
if filename and filename not in seen_filenames:
seen_filenames.add(filename)
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
referenced_docs += (
f'[{i+1}]: {formatted_file_name} Score:{doc["score"]} \n'
)
title = doc["metadata"].get("title")
if not title:
referenced_docs += f'[{file_idx}]: {formatted_file_name} Score:{doc["score"]} \n'
else:
referenced_docs += f'[{file_idx}]: [{title}]({formatted_file_name}) Score:{doc["score"]} \n'

file_idx += 1
formatted_answer = ""
if session_id:
new_query = response["new_query"]
Expand Down Expand Up @@ -138,6 +148,31 @@ def query(
text, chunk_response, session_id=session_id, stream=stream
)

def query_search(
    self,
    text: str,
    session_id: str = None,
    stream: bool = False,
):
    """Query the web-search endpoint and yield formatted responses.

    Args:
        text: The user question.
        session_id: Optional chat session identifier for multi-turn use.
        stream: When True, yield one formatted response per server chunk
            with the accumulated answer so far; otherwise yield a single
            complete response.

    Yields:
        Formatted response dict(s) produced by ``_format_rag_response``.

    Raises:
        RagApiError: If the service returns a non-OK HTTP status.
    """
    q = dict(question=text, session_id=session_id, stream=stream, with_intent=False)
    r = requests.post(self.search_url, json=q, stream=True)
    if r.status_code != HTTPStatus.OK:
        raise RagApiError(code=r.status_code, msg=r.text)
    if not stream:
        response = dotdict(json.loads(r.text))
        yield self._format_rag_response(
            text, response, session_id=session_id, stream=stream
        )
    else:
        full_content = ""
        for chunk in r.iter_lines(chunk_size=8192, decode_unicode=True):
            # iter_lines may emit empty keep-alive lines; json.loads("")
            # would raise ValueError, so skip them.
            if not chunk:
                continue
            chunk_response = dotdict(json.loads(chunk))
            # Accumulate deltas so each yielded chunk carries the full
            # answer-so-far, matching the non-streaming response shape.
            full_content += chunk_response.delta
            chunk_response.delta = full_content
            yield self._format_rag_response(
                text, chunk_response, session_id=session_id, stream=stream
            )

def query_llm(
self,
text: str,
Expand Down
75 changes: 67 additions & 8 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,38 @@ def respond(input_elements: List[Any]):
if not update_dict["include_history"]:
current_session_id = None

content = ""
chatbot.append((msg, content))

try:
if query_type == "LLM":
response_gen = rag_client.query_llm(
msg, session_id=current_session_id, stream=is_streaming
)
elif query_type == "Retrieval":
response_gen = rag_client.query_vector(msg)

elif query_type == "WebSearch":
response_gen = rag_client.query_search(
msg, session_id=current_session_id, stream=is_streaming
)
else:
response_gen = rag_client.query(
msg, session_id=current_session_id, stream=is_streaming
)
for resp in response_gen:
chatbot[-1] = (msg, resp.result)
yield chatbot

except RagApiError as api_error:
raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")

content = ""
chatbot.append((msg, content))
for resp in response_gen:
chatbot[-1] = (msg, resp.result)
yield chatbot


def create_chat_tab() -> Dict[str, Any]:
with gr.Row():
with gr.Column(scale=2):
query_type = gr.Radio(
["Retrieval", "LLM", "RAG (Retrieval + LLM)"],
["Retrieval", "LLM", "WebSearch", "RAG (Retrieval + LLM)"],
label="\N{fire} Which query do you want to use?",
elem_id="query_type",
value="RAG (Retrieval + LLM)",
Expand Down Expand Up @@ -218,6 +223,32 @@ def change_retrieval_mode(retrieval_mode):
)
llm_args = {llm_temp, include_history}

with gr.Column(visible=True) as search_col:
search_model_argument = gr.Accordion(
"Parameters of Web Search", open=False
)
with search_model_argument:
search_api_key = gr.Text(
label="Bing API Key",
value="",
type="password",
elem_id="search_api_key",
)
search_count = gr.Slider(
label="Search Count",
minimum=5,
maximum=30,
step=1,
elem_id="search_count",
)
search_lang = gr.Radio(
label="Language",
choices=["zh-CN", "en-US"],
value="zh-CN",
elem_id="search_lang",
)
search_args = {search_api_key, search_count, search_lang}

with gr.Column(visible=True) as lc_col:
prm_type = gr.Radio(
[
Expand Down Expand Up @@ -283,6 +314,8 @@ def change_query_radio(query_type):
return {
vs_col: gr.update(visible=True),
vec_model_argument: gr.update(open=True),
search_model_argument: gr.update(open=False),
search_col: gr.update(visible=False),
llm_col: gr.update(visible=False),
model_argument: gr.update(open=False),
lc_col: gr.update(visible=False),
Expand All @@ -291,14 +324,28 @@ def change_query_radio(query_type):
return {
vs_col: gr.update(visible=False),
vec_model_argument: gr.update(open=False),
search_model_argument: gr.update(open=False),
search_col: gr.update(visible=False),
llm_col: gr.update(visible=True),
model_argument: gr.update(open=True),
lc_col: gr.update(visible=False),
}
elif query_type == "WebSearch":
return {
vs_col: gr.update(visible=False),
vec_model_argument: gr.update(open=False),
search_model_argument: gr.update(open=True),
search_col: gr.update(visible=True),
llm_col: gr.update(visible=False),
model_argument: gr.update(open=False),
lc_col: gr.update(visible=False),
}
elif query_type == "RAG (Retrieval + LLM)":
return {
vs_col: gr.update(visible=True),
vec_model_argument: gr.update(open=False),
search_model_argument: gr.update(open=False),
search_col: gr.update(visible=False),
llm_col: gr.update(visible=True),
model_argument: gr.update(open=False),
lc_col: gr.update(visible=True),
Expand All @@ -307,7 +354,15 @@ def change_query_radio(query_type):
query_type.input(
fn=change_query_radio,
inputs=query_type,
outputs=[vs_col, vec_model_argument, llm_col, model_argument, lc_col],
outputs=[
vs_col,
vec_model_argument,
search_model_argument,
search_col,
llm_col,
model_argument,
lc_col,
],
)

with gr.Column(scale=8):
Expand All @@ -321,6 +376,7 @@ def change_query_radio(query_type):
{text_qa_template, question, query_type, chatbot, is_streaming}
.union(vec_args)
.union(llm_args)
.union(search_args)
)

submitBtn.click(
Expand Down Expand Up @@ -359,4 +415,7 @@ def change_query_radio(query_type):
similarity_threshold.elem_id: similarity_threshold,
prm_type.elem_id: prm_type,
text_qa_template.elem_id: text_qa_template,
search_lang.elem_id: search_lang,
search_api_key.elem_id: search_api_key,
search_count.elem_id: search_count,
}
22 changes: 22 additions & 0 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ class ViewModel(BaseModel):
retrieval_mode: str = "hybrid" # hybrid / embedding / keyword
query_rewrite_n: int = 1

# websearch
search_api_key: str = None
search_count: int = 10
search_lang: str = "zh-CN"

# postprocessor
reranker_type: str = (
"simple-weighted-reranker" # simple-weighted-reranker / model-based-reranker
Expand Down Expand Up @@ -297,6 +302,13 @@ def from_app_config(config):
"text_qa_template", None
)

search_config = config.get("search") or {}
view_model.search_api_key = search_config.get(
"search_api_key"
) or os.environ.get("BING_SEARCH_KEY")
view_model.search_lang = search_config.get("search_lang", "zh-CN")
view_model.search_count = search_config.get("search_count", 10)

return view_model

def to_app_config(self):
Expand Down Expand Up @@ -406,6 +418,12 @@ def to_app_config(self):
config["synthesizer"]["type"] = self.synthesizer_type
config["synthesizer"]["text_qa_template"] = self.text_qa_template

config["search"]["search_api_key"] = self.search_api_key or os.environ.get(
"BING_SEARCH_KEY"
)
config["search"]["search_lang"] = self.search_lang
config["search"]["search_count"] = self.search_count

return _transform_to_dict(config)

def get_local_generated_qa_file(self):
Expand Down Expand Up @@ -585,4 +603,8 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:
"value": self.get_local_evaluation_result_file(type="response")
}

# search
settings["search_api_key"] = {"value": self.search_api_key}
settings["search_lang"] = {"value": self.search_lang}
settings["search_count"] = {"value": self.search_count}
return settings
3 changes: 3 additions & 0 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ similarity_top_k = 3
retrieval_mode = "hybrid" # [hybrid, embedding, keyword, router]
query_rewrite_n = 1 # set to 1 to disable query generation

[rag.search]
search_api_key = ""

[rag.synthesizer]
type = "SimpleSummarize"
text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"
54 changes: 54 additions & 0 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,59 @@ async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:

return RetrievalResponse(docs=docs)

async def aquery_search(self, query: RagQuery):
    """Answer a query using web search results instead of the local index.

    Looks up the registered SearchModule and runs the question through it,
    packaging the returned source nodes as reference docs for the client.

    Args:
        query: RagQuery carrying the question, optional session_id and a
            stream flag.

    Returns:
        RagResponse (answer, docs, session_id, new_query) when
        query.stream is False; otherwise an async event generator that
        emits streaming chunks plus the same doc/session metadata.

    Raises:
        ValueError: If no SearchModule is configured (e.g. missing search
            API key).
    """
    session_id = query.session_id or uuid_generator()
    self.logger.debug(f"Get session ID: {session_id}.")
    if not query.question:
        # Nothing to search for; short-circuit with a friendly message.
        return RagResponse(
            answer="Empty query. Please input your question.", session_id=session_id
        )

    sessioned_config = self.config

    searcher = module_registry.get_module_with_config(
        "SearchModule", sessioned_config
    )
    if not searcher:
        raise ValueError("AI search not enabled. Please add search API key.")
    if not query.stream:
        response = await searcher.aquery(query.question)
    else:
        response = await searcher.astream_query(query.question)

    node_results = response.source_nodes
    # Web search performs no query rewriting here, so the "new query"
    # reported back to the client is just the original question.
    new_query = query.question

    # Convert retrieved nodes into serializable reference docs.
    reference_docs = [
        ContextDoc(
            text=score_node.node.get_content(),
            metadata=score_node.node.metadata,
            score=score_node.score,
        )
        for score_node in node_results
    ]

    result_info = {
        "session_id": session_id,
        "docs": reference_docs,
        "new_query": new_query,
    }

    if not query.stream:
        return RagResponse(answer=response.response, **result_info)
    else:
        # Streaming: wrap the response in an async SSE event generator.
        return event_generator_async(response=response, extra_info=result_info)

async def aquery_rag(self, query: RagQuery):
"""Query answer from RAG App asynchronously.
Expand All @@ -122,6 +175,7 @@ async def aquery_rag(self, query: RagQuery):
)

sessioned_config = self.config

if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config.rag.index.update(
Expand Down
8 changes: 8 additions & 0 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ async def aquery(self, query: RagQuery):
logger.error(traceback.format_exc())
raise UserInputError(f"Query RAG failed: {ex}")

async def aquery_search(self, query: RagQuery):
    """Run a web-search query, converting any failure to UserInputError."""
    try:
        self.check_updates()
        result = await self.rag.aquery_search(query)
    except Exception as ex:
        logger.error(traceback.format_exc())
        raise UserInputError(f"Query Search failed: {ex}")
    else:
        return result

async def aquery_llm(self, query: RagQuery):
try:
self.check_updates()
Expand Down
Loading

0 comments on commit 49961e4

Please sign in to comment.