From 9d2d26dd7800c722fe1e9400afafc05ec3a48b1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=86=E9=80=8A?=
Date: Tue, 24 Dec 2024 17:43:05 +0800
Subject: [PATCH 1/3] Fix chunking

---
 src/pai_rag/app/web/rag_client.py             |  3 ++
 .../nodeparsers/pai/pai_node_parser.py        |  1 +
 .../integrations/readers/pai_csv_reader.py    | 12 ++++-
 .../readers/utils/pai_parse_workbook.py       | 53 +++++++++++++++----
 4 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 1eddd299..f9cf29a3 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -137,6 +137,7 @@ def _format_rag_response(
         self.session_id = session_id
         for i, doc in enumerate(docs):
             filename = doc["metadata"].get("file_name", None)
+            sheet_name = doc["metadata"].get("sheet_name", None)
             ref_table = doc["metadata"].get("query_tables", None)
             invalid_flag = doc["metadata"].get("invalid_flag", 0)
             file_url = doc["metadata"].get("file_url", None)
@@ -156,6 +157,8 @@
                 """
             elif filename:
                 formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
+                if sheet_name:
+                    formatted_file_name += f">>{sheet_name}"
                 html_content = html.escape(
                     re.sub(r"<.*?>", "", doc["text"])
                 ).replace("\n", " ")
diff --git a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
index f764770b..c7d087be 100644
--- a/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
+++ b/src/pai_rag/integrations/nodeparsers/pai/pai_node_parser.py
@@ -61,6 +61,7 @@ class NodeParserConfig(BaseModel):
     "source",
     "row_number",
     "image_info_list",
+    "file_url",
 ]
 
 
diff --git a/src/pai_rag/integrations/readers/pai_csv_reader.py b/src/pai_rag/integrations/readers/pai_csv_reader.py
index 48ba3137..b89266da 100644
--- a/src/pai_rag/integrations/readers/pai_csv_reader.py
+++ b/src/pai_rag/integrations/readers/pai_csv_reader.py
@@ -13,6 +13,7 @@
 from llama_index.core.schema import Document
 from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core.node_parser.text.utils import split_by_sep
 import chardet
 import os
 
 
@@ -210,11 +211,18 @@ def __init__(
     def load_data(
         self,
         file: Path,
-        chunk_size=3000,
+        chunk_size=800,
+        chunk_overlap=60,
         extra_info: Optional[Dict] = None,
         fs: Optional[AbstractFileSystem] = None,
     ) -> List[Document]:
-        splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=60)
+        splitter = SentenceSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+            paragraph_separator="\n\n\n",
+            chunking_tokenizer_fn=split_by_sep("\n\n"),
+        )
+
         logger.info(f"Start parsing {file}.")
         docs = parse_workbook(file, oss_client=self.oss_cache, splitter=splitter)
         for doc in docs:
diff --git a/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py b/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
index 49f60832..feb11171 100644
--- a/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
+++ b/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
@@ -18,10 +18,41 @@ def __init__(self, title, data, header_row: int = -1):
         self.header_row = header_row
 
 
+def format_text(text):
+    if not isinstance(text, str):
+        return str(text)
+
+    # Remove <br> </br> tags
+    text = re.sub(r"<br[^>]*>", "", text)
+    text = re.sub(r"</br>", "", text)
+
+    # Remove <p> tags
+    text = re.sub(r"<p[^>]*>", "", text)
+    text = re.sub(r"</p>", "", text)
+
+    # Remove <span> tags
+    text = re.sub(r"<span[^>]*>", "", text)
+    text = re.sub(r"</span>", "", text)
+
+    text = re.sub(r"<div[^>]*>", "", text)
+    text = re.sub(r"</div>", "", text)
+
+    text = text.replace("&nbsp;", " ")
+    text = text.replace("&quot;", '"')
+    text = re.sub("\s\s+", " ", text)
+
+    return text
+
+
 def split_row_group(row_group, headers=[], splitter=None, form_title=None):
     """
     Split a row group into smaller pieces.
     """
+    row_size_limit = 1200
+
+    if len(row_group) == 1:
+        row_size_limit = 3000
+
     raw_text = ""
     form_title = form_title + "\n\n"
     title_text = ""
@@ -48,14 +79,11 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
 
         if is_same_value:
             if len(headers) == 0:
-                column_text = f"{first_value}\n\n\n"
+                column_text = f"{format_text(first_value)}\n\n\n"
             else:
-                column_text = f"{headers[j]}: {first_value}\n\n\n"
+                column_text = f"{headers[j]}: {format_text(first_value)}\n\n\n"
 
-            if len(column_text) <= 30:
-                title_text += column_text
-            else:
-                is_same_value = False
+            title_text += column_text
 
         is_outline_column.append(is_same_value)
 
@@ -68,9 +96,9 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
                 if not row_group[i][j]:
                     continue
                 else:
-                    raw_text += f"{row_group[i][j]}\n"
+                    raw_text += f"{format_text(row_group[i][j])}\n"
             else:
-                raw_text += f"{headers[j]}: {row_group[i][j]}\n"
+                raw_text += f"{headers[j]}: {format_text(row_group[i][j])}\n"
 
         raw_text += "\n\n"
 
@@ -85,7 +113,7 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
     raw_text = re.sub(IMAGE_REGEX, "", raw_text)
     title_text = re.sub(IMAGE_REGEX, "", title_text)
 
-    if len(raw_text) < 3000:
+    if len(raw_text) < row_size_limit:
         return [
             Document(
                 text=form_title + title_text + raw_text,
@@ -93,12 +121,16 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
             )
         ]
     else:
+        if len(row_group) == 1:
+            chunk_size = 3000
+        else:
+            chunk_size = 800
         return [
             Document(
                 text=form_title + title_text + split,
                 extra_info={"image_info_list": image_info_list},
             )
-            for split in splitter.split_text(raw_text)
+            for split in splitter._split_text(raw_text, chunk_size=chunk_size)
         ]
 
 
@@ -152,6 +184,7 @@ def chunk_form(form_title, form_data, header_row=-1, splitter=None):
                     values[i + 1][j] is not None
                     and values[i + 1][j] != ""
                     and values[i + 1][j] == values[i][j]
+                    and len(values[i + 1][j]) < 150
                 ):
                     should_merge = True
                     break

From 5956100fd9a1ff382c640b8164de533a38bcb3b5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=86=E9=80=8A?=
Date: Tue, 24 Dec 2024 18:05:00 +0800
Subject: [PATCH 2/3] Fix chunking

---
 .../integrations/readers/utils/pai_parse_workbook.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py b/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
index feb11171..93bfd06a 100644
--- a/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
+++ b/src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
@@ -65,8 +65,10 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
         ), f"Header and row data length mismatch! headers: {headers}, row: {row_group[0]}"
 
     is_outline_column = []
+
     for j in range(len(row_group[0])):
         first_value = row_group[0][j]
+
         if not first_value:
             is_outline_column.append(False)
             continue
@@ -83,7 +85,10 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
             else:
                 column_text = f"{headers[j]}: {format_text(first_value)}\n\n\n"
 
-            title_text += column_text
+            if len(column_text) < row_size_limit:
+                title_text += column_text
+            else:
+                is_same_value = False
 
         is_outline_column.append(is_same_value)
 

From cddbd070a06968125910b51485a28472221f698a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=86=E9=80=8A?=
Date: Tue, 24 Dec 2024 18:58:25 +0800
Subject: [PATCH 3/3] Add chatonly

---
 src/pai_rag/app/web/chatonly_page.py | 101 +++++++++++++++++++++++++++
 src/pai_rag/app/web/webui.py         |   9 ++-
 2 files changed, 108 insertions(+), 2 deletions(-)
 create mode 100644 src/pai_rag/app/web/chatonly_page.py

diff --git a/src/pai_rag/app/web/chatonly_page.py b/src/pai_rag/app/web/chatonly_page.py
new file mode 100644
index 00000000..5c576321
--- /dev/null
+++ b/src/pai_rag/app/web/chatonly_page.py
@@ -0,0 +1,101 @@
+import gradio as gr
+from pai_rag.app.web.rag_client import RagApiError, rag_client
+
+
+def clear_history(chatbot):
+    rag_client.clear_history()
+    chatbot = []
+    return chatbot
+
+
+def reset_textbox():
+    return gr.update(value="")
+
+
+def respond(retrieve_only, question, chatbot):
+    # empty input.
+    if not question:
+        yield chatbot
+        return
+
+    if chatbot is not None:
+        chatbot.append((question, ""))
+        yield chatbot
+
+    try:
+        if retrieve_only:
+            response_gen = rag_client.query_vector(question, index_name="default_index")
+
+        else:
+            response_gen = rag_client.query(
+                question,
+                with_history=False,
+                stream=True,
+                citation=True,
+                index_name="default_index",
+            )
+        for resp in response_gen:
+            chatbot[-1] = (question, resp.result)
+            yield chatbot
+
+    except RagApiError as api_error:
+        raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
+    except Exception as e:
+        raise gr.Error(f"Error: {e}")
+    finally:
+        yield chatbot
+
+
+def create_chat_ui():
+    with gr.Blocks() as chatpage:
+        chatbot = gr.Chatbot(height=600, elem_id="chatbot")
+        with gr.Row():
+            retrieve_only = gr.Checkbox(
+                label="Retrieve only",
+                info="Query knowledge base directly without LLM.",
+                elem_id="retrieve_only",
+                value=True,
+                scale=1,
+            )
+            question = gr.Textbox(
+                label="Enter your question.", elem_id="question", scale=9
+            )
+        with gr.Row():
+            submitBtn = gr.Button("Submit", variant="primary")
+            clearBtn = gr.Button("Clear History", variant="secondary")
+
+        submitBtn.click(
+            respond,
+            [
+                retrieve_only,
+                question,
+                chatbot,
+            ],
+            [chatbot],
+            api_name="respond_clk",
+        )
+        question.submit(
+            respond,
+            [
+                retrieve_only,
+                question,
+                chatbot,
+            ],
+            [chatbot],
+            api_name="respond_q",
+        )
+        submitBtn.click(
+            reset_textbox,
+            [],
+            [question],
+            api_name="reset_clk",
+        )
+        question.submit(
+            reset_textbox,
+            [],
+            [question],
+            api_name="reset_q",
+        )
+
+        clearBtn.click(clear_history, [chatbot], [chatbot])
+    return chatpage
diff --git a/src/pai_rag/app/web/webui.py b/src/pai_rag/app/web/webui.py
index 33150ac4..41c0ca37 100644
--- a/src/pai_rag/app/web/webui.py
+++ b/src/pai_rag/app/web/webui.py
@@ -9,6 +9,7 @@
 from pai_rag.app.web.tabs.upload_tab import create_upload_tab
 from pai_rag.app.web.tabs.chat_tab import create_chat_tab
 from pai_rag.app.web.tabs.data_analysis_tab import create_data_analysis_tab
+from pai_rag.app.web.chatonly_page import create_chat_ui
 from pai_rag.app.web.index_utils import index_related_component_keys
 
 # from pai_rag.app.web.tabs.eval_tab import create_evaluation_tab
@@ -133,7 +134,11 @@ def configure_webapp(app: FastAPI, web_url, rag_url=DEFAULT_LOCAL_URL) -> gr.Blo
     rag_client.set_endpoint(rag_url)
     home = make_homepage()
     home.queue(concurrency_count=1, max_size=64)
-    home._queue.set_url(web_url)
+    # home._queue.set_url(web_url)
     logger.info(f"web_url: {web_url}")
-    gr.mount_gradio_app(app, home, path="")
+
+    chat_page = create_chat_ui()
+    chat_page.queue(concurrency_count=1, max_size=64)
+    gr.mount_gradio_app(app, chat_page, path="/chat/")
+    gr.mount_gradio_app(app, home, path="/")
     return home
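
Note (not part of the patches above, ignored by git am): a minimal sketch of how the splitter configured in PaiCSVReader.load_data in patch 1 is expected to chunk the row text that split_row_group emits. It uses only the llama-index imports already referenced in the diff; the sample input below is invented for illustration.

from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser.text.utils import split_by_sep

# Same configuration as the patched load_data: individual rows end with "\n\n"
# and row groups are separated by "\n\n\n", so chunk boundaries fall on rows.
splitter = SentenceSplitter(
    chunk_size=800,
    chunk_overlap=60,
    paragraph_separator="\n\n\n",
    chunking_tokenizer_fn=split_by_sep("\n\n"),
)

# Hypothetical row text, shaped like the output of split_row_group.
rows = "\n\n".join(f"name: item_{i}\nprice: {i * 10}" for i in range(200))
for chunk in splitter.split_text(rows):
    print(len(chunk))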