Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom query API #284

Closed
wants to merge 10 commits into from
54 changes: 54 additions & 0 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import tempfile
import shutil
import pandas as pd
import json
from pai_rag.core.models.errors import UserInputError
from pai_rag.core.rag_index_manager import RagIndexEntry, index_manager
from pai_rag.core.rag_service import rag_service
Expand Down Expand Up @@ -272,3 +273,56 @@ async def aquery_analysis(query: RagQuery):
response,
media_type="text/event-stream",
)


def _extract_model_names(answer):
    """Return the distinct, non-null "型号" values found in a parsed LLM answer.

    Args:
        answer: The object produced by ``json.loads`` on the LLM answer. A
            list of JSON objects is expected; a single JSON object is treated
            as a one-element list. Any other shape yields an empty result.

    Returns:
        A list of the "型号" values in first-seen order, without duplicates.
    """
    if isinstance(answer, dict):
        answer = [answer]
    if not isinstance(answer, list):
        return []

    # Skip entries that are not JSON objects: calling .get() on a string or a
    # number would raise and turn a recoverable case into a generic 500.
    names = [
        item["型号"]
        for item in answer
        if isinstance(item, dict) and item.get("型号") is not None
    ]
    # dict.fromkeys de-duplicates while keeping a deterministic order,
    # unlike list(set(...)).
    return list(dict.fromkeys(names))


@router.post("/query/custom_search")
async def aquery_custom_test(query: RagQuery):
    """Ask the LLM for model names, then look those names up through SQL.

    The LLM answer is parsed as JSON, the "型号" values are extracted and
    de-duplicated, and the resulting list is passed to
    ``rag_service.sql_query``.

    Args:
        query: The query forwarded to ``rag_service.aquery_llm``.

    Returns:
        A dict with "status" and "status_code" keys, plus "data" on success
        or "message" on failure. The status code is carried inside the
        returned dict; this handler does not raise for the error cases.
    """
    try:
        response = await rag_service.aquery_llm(query)

        try:
            answer = json.loads(response.answer)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers an answer that is not str/bytes (e.g. None).
            logger.error(f"Error decoding JSON: {e}")
            return {
                "status": "error",
                "message": "Parsing Error: The LLM response is not a valid JSON format.",
                "status_code": 400,
            }

        unique_input_list = _extract_model_names(answer)
        logger.info(f"Unique input list: {unique_input_list}")
        if not unique_input_list:
            logger.warning("No model information found in response.")
            return {
                "status": "error",
                "message": "Parsing Error: The '型号' key is not found in the JSON.",
                "status_code": 404,
            }

        # TODO (carried over from the original; its scope was not stated).
        # NOTE(review): this call is not awaited; if it blocks on the database
        # it will stall the event loop - consider running it in a worker thread.
        try:
            sql_response = rag_service.sql_query(unique_input_list)
        except Exception as e:
            logger.error(f"SQL query failed: {e}")
            return {
                "status": "error",
                "message": "SQL query failed: No information found for the relevant input list.",
                "status_code": 500,
            }

        return {
            "status": "success",
            "data": sql_response,
            "status_code": 200,
        }

    except Exception as e:
        # Last-resort boundary: report a generic failure instead of leaking
        # internal details to the client.
        logger.error(f"Unexpected error: {e}")
        return {
            "status": "error",
            "message": "Unexpected error, please try again later.",
            "status_code": 500,
        }
15 changes: 15 additions & 0 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,18 @@ async def aquery_analysis(
return RagResponse(answer=response.response, **result_info)
else:
return event_generator_async(response=response, sse_version=sse_version)


def sql_query(
    self, input_list: list, sse_version: SseVersion = SseVersion.V0
):
    """Run a direct SQL lookup for *input_list* through the data analysis tool.

    Args:
        input_list: Values handed unchanged to the tool's ``sql_query``.
        sse_version: Accepted for signature parity; not used by this method.

    Returns:
        Whatever the data analysis tool's ``sql_query`` returns.

    Raises:
        ValueError: If no data analysis tool can be resolved from the config.
    """
    tool = resolve_data_analysis_tool(self.config)
    if tool:
        return tool.sql_query(input_list)

    raise ValueError("Data Analysis not enabled. Please specify analysis type.")
7 changes: 7 additions & 0 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ async def aquery_analysis_v1(self, query: RagQuery):
except Exception as ex:
logger.error(traceback.format_exc())
raise UserInputError(f"Query Analysis failed: {ex}")

def sql_query(self, input_list: List):
    """Delegate a direct SQL lookup to the RAG application.

    Args:
        input_list: Values forwarded to ``self.rag.sql_query``.

    Returns:
        The result of ``self.rag.sql_query`` called with ``SseVersion.V1``.

    Raises:
        UserInputError: For any exception raised by the delegate; the
            traceback is logged before re-raising.
    """
    try:
        outcome = self.rag.sql_query(input_list, sse_version=SseVersion.V1)
    except Exception as ex:
        logger.error(traceback.format_exc())
        raise UserInputError(f"SQL Query failed: {ex}")
    else:
        return outcome


# Module-level RagService instance; the API module imports it by this name.
rag_service = RagService()
26 changes: 26 additions & 0 deletions src/pai_rag/integrations/data_analysis/data_analysis_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from llama_index.core.settings import Settings
import llama_index.core.instrumentation as instrument

from sqlalchemy import text

from pai_rag.integrations.data_analysis.nl2sql_retriever import MyNLSQLRetriever
from pai_rag.integrations.data_analysis.data_analysis_config import (
BaseAnalysisConfig,
Expand Down Expand Up @@ -154,3 +156,27 @@ async def astream_query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
self._synthesizer._streaming = streaming

return stream_response


def sql_query(self, input_list: List) -> List[dict]:
    """Fetch the ``material_data`` rows whose 物料编码 is in *input_list*.

    The values are sent to the database as bound parameters instead of being
    interpolated into the SQL text, so a value containing quotes cannot
    change the statement.

    Args:
        input_list: Codes to look up. Each one is converted with ``str``
            before binding, matching the original string comparison.

    Returns:
        One dict per matching row, keyed by column name. An empty or missing
        *input_list* returns ``[]`` without touching the database.

    Raises:
        NotImplementedError: Re-raised with context when the engine reports
            the operation as not implemented.
        Exception: For any other failure during execution.
    """
    import logging

    log = logging.getLogger(__name__)

    codes = [str(code) for code in (input_list or [])]
    if not codes:
        # "IN ()" is rejected by many SQL dialects; nothing to look up anyway.
        return []

    # NOTE(review): column names come from the first configured table while
    # the statement targets material_data - confirm they are the same table.
    table_name = self._retriever._tables[0]
    columns = [
        item["name"]
        for item in self._retriever._sql_database.get_table_columns(table_name)
    ]
    log.debug("table: %s columns: %s", table_name, columns)

    # Build one named placeholder per value (":code_0, :code_1, ...").
    params = {f"code_{index}": code for index, code in enumerate(codes)}
    placeholders = ", ".join(f":{name}" for name in params)
    # Assemble the SQL statement; only placeholder names are interpolated.
    sql = f"SELECT * FROM material_data WHERE 物料编码 IN ({placeholders})"
    log.debug("sql: %s params: %s", sql, params)

    try:
        with self._retriever._sql_database.engine.connect() as connection:
            result = connection.execution_options(timeout=60).execute(
                text(sql), params
            )
            query_results = result.fetchall()
        return [dict(zip(columns, row)) for row in query_results]
    except NotImplementedError as error:
        raise NotImplementedError(f"SQL execution not implemented: {error}") from error
    except Exception as error:
        raise Exception(f"Unexpected error during SQL execution: {error}") from error

58 changes: 58 additions & 0 deletions src/pai_rag/tools/data_analysis_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import click
import os
import time
import sys
from pathlib import Path
from pai_rag.core.rag_config_manager import RagConfigManager
from pai_rag.core.rag_module import resolve_data_analysis_tool
from pai_rag.integrations.data_analysis.data_analysis_config import SqlAnalysisConfig

# from pai_rag.integrations.synthesizer.pai_synthesizer import PaiQueryBundle
from llama_index.core.schema import QueryBundle
import logging

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Parent of the directory that contains this file.
_BASE_DIR = Path(__file__).parent.parent
# Settings file used when no --config_file option is given.
DEFAULT_APPLICATION_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")


# Sample codes used when no -l/--input_list option is supplied.
_DEFAULT_INPUT_LIST = ("R5930 G2", "0231A5QX")


@click.command()
@click.option(
    "-c",
    "--config_file",
    show_default=True,
    help=f"Configuration file. Default: {DEFAULT_APPLICATION_CONFIG_FILE}",
    default=DEFAULT_APPLICATION_CONFIG_FILE,
)
@click.option(
    "-l",
    "--input_list",
    multiple=True,
    show_default=True,
    help="Code to look up; repeat the option to pass several values.",
    default=_DEFAULT_INPUT_LIST,
)
def run(
    config_file=None,
    input_list=None,
):
    """Run the data analysis tool's direct SQL lookup from the command line.

    Args:
        config_file: Path of the configuration file to load.
        input_list: Codes to look up, collected from repeated ``-l`` options.
            Falls back to the built-in sample codes when empty.

    Prints the loaded config, the input list, the query result, the
    "物料编码" value of every returned row, and the number of rows.
    """
    config = RagConfigManager.from_file(config_file).get_value()
    print("config:", config)

    # click delivers a tuple for multiple=True; the tool is given a list.
    input_list = list(input_list or _DEFAULT_INPUT_LIST)
    print("**Input List**: ", input_list)

    if not isinstance(config.data_analysis, SqlAnalysisConfig):
        # Say why nothing was queried instead of exiting silently.
        click.echo(
            "Skipped: data_analysis in the config is not a SqlAnalysisConfig.",
            err=True,
        )
        return

    da = resolve_data_analysis_tool(config)

    result = da.sql_query(input_list)
    print("**Answer**: ", result)
    print([item["物料编码"] for item in result])
    print(len(result))


if __name__ == "__main__":
    # Script entry point: hand control to the click command.
    run()
Loading