Skip to content

Commit

Permalink
Update sql parser (#304)
Browse files Browse the repository at this point in the history
* add description to synthesizer

* make lint

* fix bug

* fix tests

* udpate sql parse

* udpate sql parse
  • Loading branch information
aero-xi authored Dec 10, 2024
1 parent beeaf52 commit 280b319
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions src/pai_rag/integrations/data_analysis/nl2sql_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,31 @@ def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str
class DefaultSQLParser(BaseSQLParser):
"""Default SQL Parser."""

# def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str:
# """Parse response to SQL."""
# sql_query_start = response.find("SQLQuery:")
# if sql_query_start != -1:
# response = response[sql_query_start:]
# # TODO: move to removeprefix after Python 3.9+
# if response.startswith("SQLQuery:"):
# response = response[len("SQLQuery:") :]
# sql_result_start = response.find("SQLResult:")
# if sql_result_start != -1:
# response = response[:sql_result_start]
# return response.strip().strip("```").strip().strip(";").strip().lstrip("sql")

def parse_response_to_sql(self, response: str, query_bundle: QueryBundle) -> str:
"""Parse response to SQL."""
sql_query_start = response.find("SQLQuery:")
if sql_query_start != -1:
if sql_query_start != -1: # -1 means not found
response = response[sql_query_start:]
# TODO: move to removeprefix after Python 3.9+
if response.startswith("SQLQuery:"):
response = response[len("SQLQuery:") :]
sql_result_start = response.find("SQLResult:")
if sql_result_start != -1:
response = response[:sql_result_start]
return response.strip().strip("```").strip().strip(";").strip().lstrip("sql")
sql_query_end = response.find(";")
if sql_query_end != -1:
response = response[:sql_query_end].rstrip().replace("```", "")
return response.strip().replace("```", "").lstrip("sql")


def get_sql_info(sql_config: SqlAnalysisConfig):
Expand Down

0 comments on commit 280b319

Please sign in to comment.