Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix chunking #319

Merged
merged 5 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions src/pai_rag/app/web/chatonly_page.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import gradio as gr
from pai_rag.app.web.rag_client import RagApiError, rag_client


def clear_history(chatbot):
    """Reset the server-side conversation history and return an empty chat window."""
    rag_client.clear_history()
    # Hand gradio a fresh, empty history list for the Chatbot component.
    return []


def reset_textbox():
    """Return a gradio update that blanks the question input box."""
    cleared = gr.update(value="")
    return cleared


def respond(retrieve_only, question, chatbot):
    """Stream chat updates for a submitted question.

    Gradio generator handler: each ``yield chatbot`` pushes the current
    (question, partial answer) history to the Chatbot component.

    Args:
        retrieve_only: if truthy, query the vector index directly
            (no LLM generation); otherwise run the full streaming RAG query.
        question: user input text; empty input is a no-op.
        chatbot: chat history as a list of (question, answer) pairs.
    """
    # empty input.
    if not question:
        yield chatbot
        return

    # Show the question immediately with an empty answer placeholder.
    if chatbot is not None:
        chatbot.append((question, ""))
        yield chatbot

    try:
        if retrieve_only:
            # Raw retrieval from the knowledge base, skipping the LLM.
            response_gen = rag_client.query_vector(question, index_name="default_index")

        else:
            response_gen = rag_client.query(
                question,
                with_history=False,
                stream=True,
                citation=True,
                index_name="default_index",
            )
        # NOTE(review): chatbot[-1] assumes chatbot is a non-empty list here.
        # If gradio ever passed chatbot=None, the append above is skipped and
        # this would raise TypeError — confirm the component always supplies a list.
        for resp in response_gen:
            chatbot[-1] = (question, resp.result)
            yield chatbot

    except RagApiError as api_error:
        # Surface backend HTTP failures as a gradio error popup.
        raise gr.Error(f"HTTP {api_error.code} Error: {api_error.msg}")
    except Exception as e:
        raise gr.Error(f"Error: {e}")
    finally:
        # Always emit the final state so the UI is left consistent,
        # whether the stream finished, was empty, or errored out.
        yield chatbot


def create_chat_ui():
    """Assemble the chat-only gradio page.

    Layout: a Chatbot pane on top, then a row with the retrieve-only
    toggle and the question box, then submit/clear buttons wired to
    the ``respond`` / ``reset_textbox`` / ``clear_history`` handlers.
    """
    with gr.Blocks() as page:
        history = gr.Chatbot(height=600, elem_id="chatbot")

        with gr.Row():
            retrieve_only = gr.Checkbox(
                label="Retrieve only",
                info="Query knowledge base directly without LLM.",
                elem_id="retrieve_only",
                value=True,
                scale=1,
            )
            question = gr.Textbox(
                label="Enter your question.", elem_id="question", scale=9
            )

        with gr.Row():
            submit_button = gr.Button("Submit", variant="primary")
            clear_button = gr.Button("Clear History", variant="secondary")

        respond_inputs = [retrieve_only, question, history]

        # Both the Submit click and pressing Enter in the textbox run the
        # same pair of handlers: stream the answer, then blank the textbox.
        submit_button.click(
            respond, respond_inputs, [history], api_name="respond_clk"
        )
        question.submit(
            respond, respond_inputs, [history], api_name="respond_q"
        )
        submit_button.click(
            reset_textbox, [], [question], api_name="reset_clk"
        )
        question.submit(
            reset_textbox, [], [question], api_name="reset_q"
        )

        clear_button.click(clear_history, [history], [history])

    return page
3 changes: 3 additions & 0 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _format_rag_response(
self.session_id = session_id
for i, doc in enumerate(docs):
filename = doc["metadata"].get("file_name", None)
sheet_name = doc["metadata"].get("sheet_name", None)
ref_table = doc["metadata"].get("query_tables", None)
invalid_flag = doc["metadata"].get("invalid_flag", 0)
file_url = doc["metadata"].get("file_url", None)
Expand All @@ -156,6 +157,8 @@ def _format_rag_response(
"""
elif filename:
formatted_file_name = re.sub("^[0-9a-z]{32}_", "", filename)
if sheet_name:
formatted_file_name += f">>{sheet_name}"
html_content = html.escape(
re.sub(r"<.*?>", "", doc["text"])
).replace("\n", " ")
Expand Down
9 changes: 7 additions & 2 deletions src/pai_rag/app/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pai_rag.app.web.tabs.upload_tab import create_upload_tab
from pai_rag.app.web.tabs.chat_tab import create_chat_tab
from pai_rag.app.web.tabs.data_analysis_tab import create_data_analysis_tab
from pai_rag.app.web.chatonly_page import create_chat_ui
from pai_rag.app.web.index_utils import index_related_component_keys

# from pai_rag.app.web.tabs.eval_tab import create_evaluation_tab
Expand Down Expand Up @@ -133,7 +134,11 @@ def configure_webapp(app: FastAPI, web_url, rag_url=DEFAULT_LOCAL_URL) -> gr.Blo
rag_client.set_endpoint(rag_url)
home = make_homepage()
home.queue(concurrency_count=1, max_size=64)
home._queue.set_url(web_url)
# home._queue.set_url(web_url)
logger.info(f"web_url: {web_url}")
gr.mount_gradio_app(app, home, path="")

chat_page = create_chat_ui()
chat_page.queue(concurrency_count=1, max_size=64)
gr.mount_gradio_app(app, chat_page, path="/chat/")
gr.mount_gradio_app(app, home, path="/")
return home
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class NodeParserConfig(BaseModel):
"source",
"row_number",
"image_info_list",
"file_url",
]


Expand Down
12 changes: 10 additions & 2 deletions src/pai_rag/integrations/readers/pai_csv_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from llama_index.core.schema import Document

from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser.text.utils import split_by_sep

import chardet
import os
Expand Down Expand Up @@ -210,11 +211,18 @@ def __init__(
def load_data(
self,
file: Path,
chunk_size=3000,
chunk_size=800,
chunk_overlap=60,
extra_info: Optional[Dict] = None,
fs: Optional[AbstractFileSystem] = None,
) -> List[Document]:
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=60)
splitter = SentenceSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
paragraph_separator="\n\n\n",
chunking_tokenizer_fn=split_by_sep("\n\n"),
)

logger.info(f"Start parsing {file}.")
docs = parse_workbook(file, oss_client=self.oss_cache, splitter=splitter)
for doc in docs:
Expand Down
52 changes: 45 additions & 7 deletions src/pai_rag/integrations/readers/utils/pai_parse_workbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,41 @@ def __init__(self, title, data, header_row: int = -1):
self.header_row = header_row


def format_text(text):
    """Strip simple HTML markup from a spreadsheet cell value.

    Non-string values are returned as ``str(text)``. For strings, the
    inline formatting tags rich-text cells produce (``<p>``, ``<strong>``,
    ``<span>``, ``<font>``) are removed while their content is kept, the
    ``&nbsp;``/``&quot;`` entities are decoded, and any run of two or more
    whitespace characters is collapsed to a single space.

    Args:
        text: cell value of any type.

    Returns:
        The cleaned string.
    """
    if not isinstance(text, str):
        return str(text)

    # Drop opening (with any attributes) and closing formatting tags,
    # keeping the enclosed text.
    for tag in ("p", "strong", "span", "font"):
        text = re.sub(rf"<{tag}[^>]*>", "", text)
        text = re.sub(rf"</{tag}>", "", text)

    # Decode the two entities spreadsheet exports actually emit; a full
    # html.unescape would alter other entities and change behavior.
    text = text.replace("&nbsp;", " ")
    text = text.replace("&quot;", '"')

    # Raw string fixes the original "\s\s+" pattern, whose \s is an
    # invalid escape sequence in a plain string literal (SyntaxWarning
    # on Python 3.12+). Matches newlines too, by design.
    text = re.sub(r"\s\s+", " ", text)

    return text


def split_row_group(row_group, headers=[], splitter=None, form_title=None):
"""
Split a row group into smaller pieces.
"""
row_size_limit = 1200

if len(row_group) == 1:
row_size_limit = 3000

raw_text = ""
form_title = form_title + "\n\n"
title_text = ""
Expand All @@ -34,8 +65,10 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
), f"Header and row data length mismatch! headers: {headers}, row: {row_group[0]}"

is_outline_column = []

for j in range(len(row_group[0])):
first_value = row_group[0][j]

if not first_value:
is_outline_column.append(False)
continue
Expand All @@ -48,11 +81,11 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):

if is_same_value:
if len(headers) == 0:
column_text = f"{first_value}\n\n\n"
column_text = f"{format_text(first_value)}\n\n\n"
else:
column_text = f"{headers[j]}: {first_value}\n\n\n"
column_text = f"{headers[j]}: {format_text(first_value)}\n\n\n"

if len(column_text) <= 30:
if len(column_text) < row_size_limit:
title_text += column_text
else:
is_same_value = False
Expand All @@ -68,9 +101,9 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
if not row_group[i][j]:
continue
else:
raw_text += f"{row_group[i][j]}\n"
raw_text += f"{format_text(row_group[i][j])}\n"
else:
raw_text += f"{headers[j]}: {row_group[i][j]}\n"
raw_text += f"{headers[j]}: {format_text(row_group[i][j])}\n"

raw_text += "\n\n"

Expand All @@ -85,20 +118,24 @@ def split_row_group(row_group, headers=[], splitter=None, form_title=None):
raw_text = re.sub(IMAGE_REGEX, "", raw_text)
title_text = re.sub(IMAGE_REGEX, "", title_text)

if len(raw_text) < 3000:
if len(raw_text) < row_size_limit:
return [
Document(
text=form_title + title_text + raw_text,
extra_info={"image_info_list": image_info_list},
)
]
else:
if len(row_group) == 1:
chunk_size = 3000
else:
chunk_size = 800
return [
Document(
text=form_title + title_text + split,
extra_info={"image_info_list": image_info_list},
)
for split in splitter.split_text(raw_text)
for split in splitter._split_text(raw_text, chunk_size=chunk_size)
]


Expand Down Expand Up @@ -152,6 +189,7 @@ def chunk_form(form_title, form_data, header_row=-1, splitter=None):
values[i + 1][j] is not None
and values[i + 1][j] != ""
and values[i + 1][j] == values[i][j]
and len(values[i + 1][j]) < 150
):
should_merge = True
break
Expand Down
Loading