From 280b319dd27ff08373a609fa322ef6e995037241 Mon Sep 17 00:00:00 2001 From: aero-xi Date: Tue, 10 Dec 2024 17:29:16 +0800 Subject: [PATCH] Update sql parser (#304) * add description to synthesizer * make lint * fix bug * fix tests * udpate sql parse * udpate sql parse --- .../data_analysis/nl2sql_retriever.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py index ebfec1d4..b3e30118 100644 --- a/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py +++ b/src/pai_rag/integrations/data_analysis/nl2sql_retriever.py @@ -227,18 +227,31 @@ def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str class DefaultSQLParser(BaseSQLParser): """Default SQL Parser.""" + # def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str: + # """Parse response to SQL.""" + # sql_query_start = response.find("SQLQuery:") + # if sql_query_start != -1: + # response = response[sql_query_start:] + # # TODO: move to removeprefix after Python 3.9+ + # if response.startswith("SQLQuery:"): + # response = response[len("SQLQuery:") :] + # sql_result_start = response.find("SQLResult:") + # if sql_result_start != -1: + # response = response[:sql_result_start] + # return response.strip().strip("```").strip().strip(";").strip().lstrip("sql") + def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str: """Parse response to SQL.""" sql_query_start = response.find("SQLQuery:") - if sql_query_start != -1: + if sql_query_start != -1: # -1 means not found response = response[sql_query_start:] # TODO: move to removeprefix after Python 3.9+ if response.startswith("SQLQuery:"): response = response[len("SQLQuery:") :] - sql_result_start = response.find("SQLResult:") - if sql_result_start != -1: - response = response[:sql_result_start] - return response.strip().strip("```").strip().strip(";").strip().lstrip("sql") + sql_query_end = response.find(";") + if sql_query_end != -1: + response = response[:sql_query_end].rstrip().replace("```", "") + return response.strip().replace("```", "").lstrip("sql") def get_sql_info(sql_config: SqlAnalysisConfig):