Skip to content

Commit

Permalink
Fix v1 api bug (#274)
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 authored Nov 18, 2024
1 parent 6585d48 commit 2ec98f4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/pai_rag/app/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def aquery_retrieval(query: RetrievalQuery):

@router_v1.post("/query/agent")
async def aquery_agent(query: RagQuery):
response = await rag_service.aquery_agent(query)
response = await rag_service.aquery_agent_v1(query)
if not query.stream:
return response
else:
Expand Down Expand Up @@ -264,7 +264,7 @@ async def upload_datasheet(

@router_v1.post("/query/data_analysis")
async def aquery_analysis(query: RagQuery):
response = await rag_service.aquery_analysis(query)
response = await rag_service.aquery_analysis_v1(query)
if not query.stream:
return response
else:
Expand Down
14 changes: 9 additions & 5 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def aquery(
intent = await intent_router.aselect(str_or_query_bundle=new_question)
logger.info(f"[IntentDetection] Routing query to {intent}.")
if intent == Intents.TOOL:
return await self.aquery_agent(query)
return await self.aquery_agent(query, sse_version=sse_version)
elif intent == Intents.WEBSEARCH:
chat_type = RagChatType.WEB
elif intent == Intents.NL2SQL:
Expand Down Expand Up @@ -260,7 +260,9 @@ async def aquery(
sse_version=sse_version,
)

async def aquery_agent(self, query: RagQuery) -> RagResponse:
async def aquery_agent(
self, query: RagQuery, sse_version: SseVersion = SseVersion.V0
) -> RagResponse:
"""Query answer from RAG App via web search asynchronously.
Generate answer from agent's achat interface.
Expand All @@ -277,7 +279,7 @@ async def aquery_agent(self, query: RagQuery) -> RagResponse:
agent = resolve_agent(self.config)
if query.stream:
response = await agent.astream_chat(query.question)
return event_generator_async(response)
return event_generator_async(response, sse_version=sse_version)
else:
response = await agent.achat(query.question)
return RagResponse(answer=response.response)
Expand Down Expand Up @@ -306,7 +308,9 @@ async def aload_agent_config(self, agent_cfg_path: str):
else:
return f"The agent config path {agent_cfg_path} not exists."

async def aquery_analysis(self, query: RagQuery):
async def aquery_analysis(
self, query: RagQuery, sse_version: SseVersion = SseVersion.V0
):
"""Query answer from RAG App asynchronously.
Generate answer from Data Analysis interface.
Expand Down Expand Up @@ -361,4 +365,4 @@ async def aquery_analysis(self, query: RagQuery):
if not query.stream:
return RagResponse(answer=response.response, **result_info)
else:
return event_generator_async(response=response, extra_info=result_info)
return event_generator_async(response=response, sse_version=sse_version)
14 changes: 14 additions & 0 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,13 @@ async def aquery_agent(self, query: RagQuery) -> RagResponse:
logger.error(traceback.format_exc())
raise UserInputError(f"Query RAG Agent failed: {ex}")

async def aquery_agent_v1(self, query: RagQuery) -> RagResponse:
try:
return await self.rag.aquery_agent(query, sse_version=SseVersion.V1)
except Exception as ex:
logger.error(traceback.format_exc())
raise UserInputError(f"Query RAG Agent failed: {ex}")

async def aload_agent_config(self, agent_cfg_path: str):
try:
return await self.rag.aload_agent_config(agent_cfg_path)
Expand All @@ -194,5 +201,12 @@ async def aquery_analysis(self, query: RagQuery):
logger.error(traceback.format_exc())
raise UserInputError(f"Query Analysis failed: {ex}")

async def aquery_analysis_v1(self, query: RagQuery):
try:
return await self.rag.aquery_analysis(query, sse_version=SseVersion.V1)
except Exception as ex:
logger.error(traceback.format_exc())
raise UserInputError(f"Query Analysis failed: {ex}")


rag_service = RagService()

0 comments on commit 2ec98f4

Please sign in to comment.