From 4b43831fe6e4fbb25c30bcbdcaf37e18035a6cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Tue, 26 Nov 2024 20:23:15 +0800 Subject: [PATCH 1/6] Add custom query api --- src/pai_rag/app/api/query.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index e42bbafc..94f0a27e 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -7,6 +7,7 @@ import tempfile import shutil import pandas as pd +import json from pai_rag.core.models.errors import UserInputError from pai_rag.core.rag_index_manager import RagIndexEntry, index_manager from pai_rag.core.rag_service import rag_service @@ -272,3 +273,20 @@ async def aquery_analysis(query: RagQuery): response, media_type="text/event-stream", ) + + +@router.post("/query/custom_test") +async def aquery_custom_test(query: RagQuery): + response = await rag_service.aquery_llm(query) + answer = json.loads(response.answer) + input_list = [res["型号"] for res in answer] + unique_input_list = list(set(input_list)) + # TODO + response = await rag_service.aquery_analysis(unique_input_list) + if not query.stream: + return response + else: + return StreamingResponse( + response, + media_type="text/event-stream", + ) From 15e4eb0c2a1d4aea2587d7330837ed34eb945139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Wed, 27 Nov 2024 09:59:43 +0800 Subject: [PATCH 2/6] Add more log info and error --- src/pai_rag/app/api/query.py | 43 ++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index 94f0a27e..967a51be 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -275,18 +275,33 @@ async def aquery_analysis(query: RagQuery): ) -@router.post("/query/custom_test") +@router.post("/query/custom_search") async def aquery_custom_test(query: RagQuery): - response = await rag_service.aquery_llm(query) - answer = json.loads(response.answer) - input_list = [res["型号"] for res in answer] - unique_input_list = list(set(input_list)) - # TODO - response = await rag_service.aquery_analysis(unique_input_list) - if not query.stream: - return response - else: - return StreamingResponse( - response, - media_type="text/event-stream", - ) + try: + response = await rag_service.aquery_llm(query) + + try: + answer = json.loads(response.answer) + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON: {e}") + return "Parsing Error: The LLM response is not a valid JSON format." + + input_list = [res.get("型号") for res in answer if "型号" in res] + logger.info(f"Extracted input list: {input_list}") + if not input_list: + logger.warning("No model information found in response.") + return "Parsing Error: The '型号' key is not found in the JSON." + + unique_input_list = list(set(input_list)) + logger.info(f"Unique input list: {unique_input_list}") + # TODO + try: + sql_response = await rag_service.aquery_analysis(unique_input_list) + return sql_response + except Exception as e: + logger.error(f"SQL query failed: {e}") + return "SQL query failed: No information found for the relevant input list." + + except Exception as e: + logger.error(f"Unexpected error: {e}") + return "Unexpected error, please try again later." 
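The two patches above hinge on a parsing contract: rag_service.aquery_llm must come back with an answer that is a JSON array of objects, each carrying a 型号 (model number) field, which the handler deduplicates before passing to the analysis step. The standalone sketch below only illustrates that contract; raw_answer is a hypothetical stand-in for response.answer, and the sample model numbers are the test values used later in this series.

import json

# Hypothetical stand-in for response.answer returned by rag_service.aquery_llm:
# a JSON array of objects, each with a "型号" (model number) key.
raw_answer = '[{"型号": "R5930 G2"}, {"型号": "0231A5QX"}, {"型号": "R5930 G2"}]'

answer = json.loads(raw_answer)  # json.JSONDecodeError here means the LLM reply was not valid JSON
input_list = [res.get("型号") for res in answer if "型号" in res]
unique_input_list = list(set(input_list))  # deduplicate before querying the database

print(unique_input_list)  # e.g. ['R5930 G2', '0231A5QX'] (set order is not guaranteed)

If the array is empty, or no object carries the 型号 key, input_list stays empty and the endpoint reports a parsing error instead of running the analysis query.
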
From 9fc1d4b9e8338c9c412f5cb24452d8ea3201e437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Wed, 27 Nov 2024 10:58:32 +0800 Subject: [PATCH 3/6] Return valid json data --- src/pai_rag/app/api/query.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index 967a51be..ce32def6 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -284,24 +284,45 @@ async def aquery_custom_test(query: RagQuery): answer = json.loads(response.answer) except json.JSONDecodeError as e: logger.error(f"Error decoding JSON: {e}") - return "Parsing Error: The LLM response is not a valid JSON format." + return { + "status": "error", + "message": "Parsing Error: The LLM response is not a valid JSON format.", + "status_code": 400, + } input_list = [res.get("型号") for res in answer if "型号" in res] logger.info(f"Extracted input list: {input_list}") if not input_list: logger.warning("No model information found in response.") - return "Parsing Error: The '型号' key is not found in the JSON." + return { + "status": "error", + "message": "Parsing Error: The '型号' key is not found in the JSON.", + "status_code": 404, + } unique_input_list = list(set(input_list)) logger.info(f"Unique input list: {unique_input_list}") + # TODO try: sql_response = await rag_service.aquery_analysis(unique_input_list) - return sql_response + return { + "status": "success", + "data": sql_response.answer, + "status_code": 200, + } except Exception as e: logger.error(f"SQL query failed: {e}") - return "SQL query failed: No information found for the relevant input list." + return { + "status": "error", + "message": "SQL query failed: No information found for the relevant input list.", + "status_code": 500, + } except Exception as e: logger.error(f"Unexpected error: {e}") - return "Unexpected error, please try again later." 
+ return { + "status": "error", + "message": "Unexpected error, please try again later.", + "status_code": 500, + } From 1a7ed9c83179dacb04dddb64933a4f911e3a65c0 Mon Sep 17 00:00:00 2001 From: chuyu Date: Wed, 27 Nov 2024 14:03:04 +0800 Subject: [PATCH 4/6] add sql_query --- src/pai_rag/app/api/query.py | 4 +- src/pai_rag/core/rag_application.py | 15 +++++ src/pai_rag/core/rag_service.py | 7 +++ .../data_analysis/data_analysis_tool.py | 26 +++++++++ src/pai_rag/tools/data_analysis_tool.py | 58 +++++++++++++++++++ 5 files changed, 108 insertions(+), 2 deletions(-) diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index ce32def6..432445e1 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -305,10 +305,10 @@ async def aquery_custom_test(query: RagQuery): # TODO try: - sql_response = await rag_service.aquery_analysis(unique_input_list) + sql_response = rag_service.sql_query(unique_input_list) return { "status": "success", - "data": sql_response.answer, + "data": sql_response, "status_code": 200, } except Exception as e: diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index 56257bf0..f50814b5 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -366,3 +366,18 @@ async def aquery_analysis( return RagResponse(answer=response.response, **result_info) else: return event_generator_async(response=response, sse_version=sse_version) + + + def sql_query( + self, input_list: list, sse_version: SseVersion = SseVersion.V0 + ): + # session_id = query.session_id or uuid_generator() + # logger.debug(f"Get session ID: {session_id}.") + + analysis_tool = resolve_data_analysis_tool(self.config) + if not analysis_tool: + raise ValueError("Data Analysis not enabled. 
Please specify analysis type.") + + result = analysis_tool.sql_query(input_list) + + return result \ No newline at end of file diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index 69a3d978..8b609e4f 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -207,6 +207,13 @@ async def aquery_analysis_v1(self, query: RagQuery): except Exception as ex: logger.error(traceback.format_exc()) raise UserInputError(f"Query Analysis failed: {ex}") + + def sql_query(self, input_list: List): + try: + return self.rag.sql_query(input_list, sse_version=SseVersion.V1) + except Exception as ex: + logger.error(traceback.format_exc()) + raise UserInputError(f"SQL Query failed: {ex}") rag_service = RagService() diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py index cd109b56..1370fecc 100644 --- a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py +++ b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py @@ -12,6 +12,8 @@ from llama_index.core.settings import Settings import llama_index.core.instrumentation as instrument +from sqlalchemy import text + from pai_rag.integrations.data_analysis.nl2sql_retriever import MyNLSQLRetriever from pai_rag.integrations.data_analysis.data_analysis_config import ( BaseAnalysisConfig, @@ -154,3 +156,27 @@ async def astream_query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: self._synthesizer._streaming = streaming return stream_response + + + def sql_query(self, input_list: List) -> List[dict]: + """Query the material database directly.""" + table_name = self._retriever._tables[0] + print("table:", table_name) + columns = [item["name"] for item in self._retriever._sql_database.get_table_columns(table_name)] + print("columns:", columns) + # 使用字符串格式化生成值列表 + value_list = ", ".join(f""" "{code}" """.strip() for code in input_list) + # 构建 SQL 查询 + sql = f"SELECT * FROM material_data WHERE 物料编码 IN ({value_list})" + print("sql:", sql) + try: + with self._retriever._sql_database.engine.connect() as connection: + result = connection.execution_options(timeout=60).execute(text(sql)) + query_results = result.fetchall() + result_json = [dict(zip(columns, sublist)) for sublist in query_results] + return result_json + except NotImplementedError as error: + raise NotImplementedError(f"SQL execution not implemented: {error}") from error + except Exception as error: + raise Exception(f"Unexpected error during SQL execution: {error}") from error + diff --git a/src/pai_rag/tools/data_analysis_tool.py b/src/pai_rag/tools/data_analysis_tool.py index e69de29b..285d281f 100644 --- a/src/pai_rag/tools/data_analysis_tool.py +++ b/src/pai_rag/tools/data_analysis_tool.py @@ -0,0 +1,58 @@ +import click +import os +import time +import sys +from pathlib import Path +from pai_rag.core.rag_config_manager import RagConfigManager +from pai_rag.core.rag_module import resolve_data_analysis_tool +from pai_rag.integrations.data_analysis.data_analysis_config import SqlAnalysisConfig + +# from pai_rag.integrations.synthesizer.pai_synthesizer import PaiQueryBundle +from llama_index.core.schema import QueryBundle +import logging + +logger = logging.getLogger(__name__) + +_BASE_DIR = Path(__file__).parent.parent +DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml") + + +@click.command() +@click.option( + "-c", + "--config_file", + show_default=True, + help=f"Configuration file. 
Default: {DEFAULT_APPLICATION_CONFIG_FILE}", + default=DEFAULT_APPLICATION_CONFIG_FILE, +) +# @click.option( +# "-l", +# "--input_list", +# type=list, +# required=True, +# help="input list", +# ) + +def run( + config_file=None, + # input_list=None, +): + + config = RagConfigManager.from_file(config_file).get_value() + print("config:", config) + + input_list = ["R5930 G2","0231A5QX"] + print("**Input List**: ", input_list) + + if isinstance(config.data_analysis, SqlAnalysisConfig): + da = resolve_data_analysis_tool(config) + + result = da.sql_query(input_list) + print("**Answer**: ", result) + print([item["物料编码"] for item in result]) + print(len(result)) + + + +if __name__ == "__main__": + run() From 7de150e9b82dc8b50e281136a2ec1337f704be6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Wed, 27 Nov 2024 16:23:27 +0800 Subject: [PATCH 5/6] Add custom nl2sql --- src/pai_rag/app/api/query.py | 54 ------------------ src/pai_rag/app/api/v1/chat.py | 56 +++++++++++++++++++ src/pai_rag/app/web/rag_client.py | 28 ++++++++++ src/pai_rag/app/web/tabs/data_analysis_tab.py | 17 +++++- src/pai_rag/core/rag_application.py | 7 +-- src/pai_rag/core/rag_service.py | 2 +- .../data_analysis/data_analysis_tool.py | 22 +++++--- src/pai_rag/tools/data_analysis_tool.py | 8 +-- 8 files changed, 119 insertions(+), 75 deletions(-) diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py index 432445e1..e42bbafc 100644 --- a/src/pai_rag/app/api/query.py +++ b/src/pai_rag/app/api/query.py @@ -7,7 +7,6 @@ import tempfile import shutil import pandas as pd -import json from pai_rag.core.models.errors import UserInputError from pai_rag.core.rag_index_manager import RagIndexEntry, index_manager from pai_rag.core.rag_service import rag_service @@ -273,56 +272,3 @@ async def aquery_analysis(query: RagQuery): response, media_type="text/event-stream", ) - - -@router.post("/query/custom_search") -async def aquery_custom_test(query: RagQuery): - try: - response = await rag_service.aquery_llm(query) - - try: - answer = json.loads(response.answer) - except json.JSONDecodeError as e: - logger.error(f"Error decoding JSON: {e}") - return { - "status": "error", - "message": "Parsing Error: The LLM response is not a valid JSON format.", - "status_code": 400, - } - - input_list = [res.get("型号") for res in answer if "型号" in res] - logger.info(f"Extracted input list: {input_list}") - if not input_list: - logger.warning("No model information found in response.") - return { - "status": "error", - "message": "Parsing Error: The '型号' key is not found in the JSON.", - "status_code": 404, - } - - unique_input_list = list(set(input_list)) - logger.info(f"Unique input list: {unique_input_list}") - - # TODO - try: - sql_response = rag_service.sql_query(unique_input_list) - return { - "status": "success", - "data": sql_response, - "status_code": 200, - } - except Exception as e: - logger.error(f"SQL query failed: {e}") - return { - "status": "error", - "message": "SQL query failed: No information found for the relevant input list.", - "status_code": 500, - } - - except Exception as e: - logger.error(f"Unexpected error: {e}") - return { - "status": "error", - "message": "Unexpected error, please try again later.", - "status_code": 500, - } diff --git a/src/pai_rag/app/api/v1/chat.py b/src/pai_rag/app/api/v1/chat.py index dbe7ea2a..9e1fda77 100644 --- a/src/pai_rag/app/api/v1/chat.py +++ b/src/pai_rag/app/api/v1/chat.py @@ -4,6 +4,7 @@ import uuid import hashlib import os +import json import tempfile import shutil 
import pandas as pd @@ -272,3 +273,58 @@ async def aquery_analysis(query: RagQuery): response, media_type="text/event-stream", ) + + +@router_v1.post("/query/custom_search") +async def aquery_custom_test(query: RagQuery): + try: + response = await rag_service.aquery_llm(query) + + try: + answer = json.loads(response.answer) + except json.JSONDecodeError as e: + logger.error(f"Error decoding JSON: {e}") + return { + "status": "error", + "message": "Parsing Error: The LLM response is not a valid JSON format.", + "status_code": 400, + } + + input_list = [res.get("型号") for res in answer if "型号" in res] + logger.info(f"Extracted input list: {input_list}") + if not input_list: + logger.warning("No model information found in response.") + return { + "status": "error", + "message": "Parsing Error: The '型号' key is not found in the JSON.", + "status_code": 404, + } + + unique_input_list = list(set(input_list)) + logger.info(f"Unique input list: {unique_input_list}") + + try: + sql_response = rag_service.sql_query(unique_input_list) + return { + "status": "success", + "data": { + "input": unique_input_list, + "output": sql_response, + }, + "status_code": 200, + } + except Exception as e: + logger.error(f"SQL query failed: {e}") + return { + "status": "error", + "message": "SQL query failed: No information found for the relevant input list.", + "status_code": 500, + } + + except Exception as e: + logger.error(f"Unexpected error: {e}") + return { + "status": "error", + "message": "Unexpected error, please try again later.", + "status_code": 500, + } diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py index 9b11d38b..5fea36e9 100644 --- a/src/pai_rag/app/web/rag_client.py +++ b/src/pai_rag/app/web/rag_client.py @@ -55,6 +55,10 @@ def search_url(self): def data_analysis_url(self): return f"{self.endpoint}v1/query/data_analysis" + @property + def custom_search_url(self): + return f"{self.endpoint}v1/query/custom_search" + @property def llm_url(self): return f"{self.endpoint}v1/query/llm" @@ -287,6 +291,30 @@ def query_data_analysis( chunk_response.delta = full_content yield self._format_rag_response(text, chunk_response, stream=stream) + def query_custom_search( + self, + text: str, + with_history: bool = False, + stream: bool = False, + ): + session_id = None + q = dict( + question=text, + session_id=session_id, + stream=stream, + ) + r = requests.post(self.custom_search_url, json=q, stream=False) + if r.status_code != HTTPStatus.OK: + raise RagApiError(code=r.status_code, msg=r.text) + response_json = dotdict(r.json()) + output = json.dumps( + response_json["data"]["output"], ensure_ascii=False, indent=4 + ) + response_json[ + "result" + ] = f"**Extracted Info**: {response_json['data']['input']} \n\n **SQL Results**: \n```json {output}" + yield response_json + def query_llm( self, text: str, diff --git a/src/pai_rag/app/web/tabs/data_analysis_tab.py b/src/pai_rag/app/web/tabs/data_analysis_tab.py index 82208786..e3fadcd3 100644 --- a/src/pai_rag/app/web/tabs/data_analysis_tab.py +++ b/src/pai_rag/app/web/tabs/data_analysis_tab.py @@ -71,7 +71,11 @@ def respond(input_elements: List[Any]): chatbot.append((question, "")) try: - response_gen = rag_client.query_data_analysis(question, stream=True) + if update_dict["custom_sql_query"] == "custom": + response_gen = rag_client.query_custom_search(question, stream=False) + else: + response_gen = rag_client.query_data_analysis(question, stream=True) + for resp in response_gen: chatbot[-1] = (question, resp.result) yield chatbot @@ 
-157,6 +161,16 @@ def create_data_analysis_tab() -> Dict[str, Any]: elem_id="db_descriptions", placeholder='A dict of table descriptions, e.g. {"table_A": "text_description_A", "table_B": "text_description_B"}', ) + with gr.Row(): + custom_sql_query = gr.Radio( + choices=[ + "default", + "custom", + ], + value="default", + label="Please choose the custom sql query type", + elem_id="custom_sql_query", + ) with gr.Column(visible=True): with gr.Tab("Nl2sql Prompt"): sql_prompt_type = gr.Radio( @@ -268,6 +282,7 @@ def data_analysis_type_change(type_value): database, tables, descriptions, + custom_sql_query, db_nl2sql_prompt, synthesizer_prompt, question, diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index f50814b5..debe6571 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -367,10 +367,7 @@ async def aquery_analysis( else: return event_generator_async(response=response, sse_version=sse_version) - - def sql_query( - self, input_list: list, sse_version: SseVersion = SseVersion.V0 - ): + def sql_query(self, input_list: list, sse_version: SseVersion = SseVersion.V0): # session_id = query.session_id or uuid_generator() # logger.debug(f"Get session ID: {session_id}.") @@ -380,4 +377,4 @@ def sql_query( result = analysis_tool.sql_query(input_list) - return result \ No newline at end of file + return result diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index 8b609e4f..d4ee7fc2 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -207,7 +207,7 @@ async def aquery_analysis_v1(self, query: RagQuery): except Exception as ex: logger.error(traceback.format_exc()) raise UserInputError(f"Query Analysis failed: {ex}") - + def sql_query(self, input_list: List): try: return self.rag.sql_query(input_list, sse_version=SseVersion.V1) diff --git a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py index 1370fecc..7180f7d7 100644 --- a/src/pai_rag/integrations/data_analysis/data_analysis_tool.py +++ b/src/pai_rag/integrations/data_analysis/data_analysis_tool.py @@ -24,6 +24,7 @@ from pai_rag.integrations.data_analysis.data_analysis_synthesizer import ( DataAnalysisSynthesizer, ) +from loguru import logger dispatcher = instrument.get_dispatcher(__name__) @@ -157,18 +158,20 @@ async def astream_query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: return stream_response - def sql_query(self, input_list: List) -> List[dict]: """Query the material database directly.""" table_name = self._retriever._tables[0] - print("table:", table_name) - columns = [item["name"] for item in self._retriever._sql_database.get_table_columns(table_name)] - print("columns:", columns) + logger.info(f"table: {table_name}") + columns = [ + item["name"] + for item in self._retriever._sql_database.get_table_columns(table_name) + ] + logger.info(f"columns: {columns}") # 使用字符串格式化生成值列表 value_list = ", ".join(f""" "{code}" """.strip() for code in input_list) # 构建 SQL 查询 sql = f"SELECT * FROM material_data WHERE 物料编码 IN ({value_list})" - print("sql:", sql) + logger.info(f"sql: {sql}") try: with self._retriever._sql_database.engine.connect() as connection: result = connection.execution_options(timeout=60).execute(text(sql)) @@ -176,7 +179,10 @@ def sql_query(self, input_list: List) -> List[dict]: result_json = [dict(zip(columns, sublist)) for sublist in query_results] return result_json except NotImplementedError as error: - 
raise NotImplementedError(f"SQL execution not implemented: {error}") from error + raise NotImplementedError( + f"SQL execution not implemented: {error}" + ) from error except Exception as error: - raise Exception(f"Unexpected error during SQL execution: {error}") from error - + raise Exception( + f"Unexpected error during SQL execution: {error}" + ) from error diff --git a/src/pai_rag/tools/data_analysis_tool.py b/src/pai_rag/tools/data_analysis_tool.py index 285d281f..02dc53a0 100644 --- a/src/pai_rag/tools/data_analysis_tool.py +++ b/src/pai_rag/tools/data_analysis_tool.py @@ -1,14 +1,11 @@ import click import os -import time -import sys from pathlib import Path from pai_rag.core.rag_config_manager import RagConfigManager from pai_rag.core.rag_module import resolve_data_analysis_tool from pai_rag.integrations.data_analysis.data_analysis_config import SqlAnalysisConfig # from pai_rag.integrations.synthesizer.pai_synthesizer import PaiQueryBundle -from llama_index.core.schema import QueryBundle import logging logger = logging.getLogger(__name__) @@ -33,15 +30,15 @@ # help="input list", # ) + def run( config_file=None, # input_list=None, ): - config = RagConfigManager.from_file(config_file).get_value() print("config:", config) - input_list = ["R5930 G2","0231A5QX"] + input_list = ["R5930 G2", "0231A5QX"] print("**Input List**: ", input_list) if isinstance(config.data_analysis, SqlAnalysisConfig): @@ -53,6 +50,5 @@ def run( print(len(result)) - if __name__ == "__main__": run() From 7bd3e3d764a2fd5d1bc8e1dc35d91f6603826b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Wed, 27 Nov 2024 16:30:25 +0800 Subject: [PATCH 6/6] Use UserInputError --- src/pai_rag/app/api/v1/chat.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/src/pai_rag/app/api/v1/chat.py b/src/pai_rag/app/api/v1/chat.py index 9e1fda77..01f59406 100644 --- a/src/pai_rag/app/api/v1/chat.py +++ b/src/pai_rag/app/api/v1/chat.py @@ -8,7 +8,7 @@ import tempfile import shutil import pandas as pd -from pai_rag.core.models.errors import UserInputError +from pai_rag.core.models.errors import UserInputError, ServiceError from pai_rag.core.rag_index_manager import RagIndexEntry, index_manager from pai_rag.core.rag_service import rag_service from pai_rag.app.api.models import ( @@ -284,21 +284,16 @@ async def aquery_custom_test(query: RagQuery): answer = json.loads(response.answer) except json.JSONDecodeError as e: logger.error(f"Error decoding JSON: {e}") - return { - "status": "error", - "message": "Parsing Error: The LLM response is not a valid JSON format.", - "status_code": 400, - } - + raise UserInputError( + "Parsing Error: The LLM response is not a valid JSON format." + ) input_list = [res.get("型号") for res in answer if "型号" in res] logger.info(f"Extracted input list: {input_list}") if not input_list: logger.warning("No model information found in response.") - return { - "status": "error", - "message": "Parsing Error: The '型号' key is not found in the JSON.", - "status_code": 404, - } + raise UserInputError( + "Parsing Error: The '型号' key is not found in the JSON." 
+ ) unique_input_list = list(set(input_list)) logger.info(f"Unique input list: {unique_input_list}") @@ -315,16 +310,8 @@ async def aquery_custom_test(query: RagQuery): } except Exception as e: logger.error(f"SQL query failed: {e}") - return { - "status": "error", - "message": "SQL query failed: No information found for the relevant input list.", - "status_code": 500, - } + raise ServiceError("SQL query failed.") except Exception as e: logger.error(f"Unexpected error: {e}") - return { - "status": "error", - "message": "Unexpected error, please try again later.", - "status_code": 500, - } + raise ServiceError("Unexpected error, please try again later.")
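
Patches 4 and 5 introduce DataAnalysisTool.sql_query, which builds its IN-clause by interpolating the material codes directly into the SQL string with an f-string before executing it against material_data.物料编码. The sketch below is not the project's implementation; it swaps that interpolation for SQLAlchemy's expanding bind parameter so the driver handles quoting of the code list. The sqlite URL is a placeholder assumption, and the table and column names are reused from the patch for illustration only.

from sqlalchemy import create_engine, text, bindparam

# Parameter-bound variant of the IN-clause query from DataAnalysisTool.sql_query.
# Table and column names come from the patch; the engine URL is an assumption.
engine = create_engine("sqlite:///material.db")

codes = ["R5930 G2", "0231A5QX"]
stmt = text("SELECT * FROM material_data WHERE 物料编码 IN :codes").bindparams(
    bindparam("codes", expanding=True)  # expands to the right number of placeholders at execution
)

with engine.connect() as connection:
    rows = connection.execute(stmt, {"codes": codes}).mappings().all()
    result_json = [dict(row) for row in rows]  # list of {column: value} dicts, as in the patch

print(result_json)

With SQLAlchemy 1.4 or later, .mappings() yields dict-like rows, which would stand in for the manual dict(zip(columns, row)) step used in the patch.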