From ef7090b63fd1facca6a9cca2cb6b8b6c04404c56 Mon Sep 17 00:00:00 2001
From: wwxxzz
Date: Thu, 13 Jun 2024 16:44:49 +0800
Subject: [PATCH] Bugfix: connection error for long-running upload tasks (#62)

* Fix connection error for long-running upload jobs

* Fix test case bugs

* Support num_workers for the embedding model

* Refactor query API and add DataFrame upload-status UI

* Refactor query API

* Remove embedding workers
---
 src/pai_rag/app/api/query.py               | 23 +++++++---
 src/pai_rag/app/web/rag_client.py          | 15 ++++++-
 src/pai_rag/app/web/tabs/upload_tab.py     | 51 +++++++++++++++++-----
 src/pai_rag/core/rag_application.py        |  4 +-
 src/pai_rag/core/rag_service.py            | 17 ++++++--
 src/pai_rag/data/rag_dataloader.py         | 11 ++++-
 src/pai_rag/modules/embedding/embedding.py |  3 +-
 tests/core/test_rag_application.py         |  4 +-
 8 files changed, 101 insertions(+), 27 deletions(-)

diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py
index 60bd0d54..dc0ec7fc 100644
--- a/src/pai_rag/app/api/query.py
+++ b/src/pai_rag/app/api/query.py
@@ -1,5 +1,6 @@
 from typing import Any
-from fastapi import APIRouter, Body
+from fastapi import APIRouter, Body, BackgroundTasks
+import uuid
 from pai_rag.core.rag_service import rag_service
 from pai_rag.app.api.models import (
     RagQuery,
@@ -39,12 +40,22 @@ async def aupdate(new_config: Any = Body(None)):
     return {"msg": "Update RAG configuration successfully."}
 
 
-@router.post("/data")
-async def load_data(input: DataInput):
-    await rag_service.add_knowledge(
-        file_dir=input.file_path, enable_qa_extraction=input.enable_qa_extraction
+@router.post("/upload_data")
+async def load_data(input: DataInput, background_tasks: BackgroundTasks):
+    task_id = uuid.uuid4().hex
+    background_tasks.add_task(
+        rag_service.add_knowledge_async,
+        task_id=task_id,
+        file_dir=input.file_path,
+        enable_qa_extraction=input.enable_qa_extraction,
     )
-    return {"msg": "Update RAG configuration successfully."}
+    return {"task_id": task_id}
+
+
+@router.get("/get_upload_state")
+def task_status(task_id: str):
+    status = rag_service.get_task_status(task_id)
+    return {"task_id": task_id, "status": status}
 
 
 @router.post("/evaluate/response")
diff --git a/src/pai_rag/app/web/rag_client.py b/src/pai_rag/app/web/rag_client.py
index 35e88508..118b923d 100644
--- a/src/pai_rag/app/web/rag_client.py
+++ b/src/pai_rag/app/web/rag_client.py
@@ -41,7 +41,11 @@ def config_url(self):
 
     @property
     def load_data_url(self):
-        return f"{self.endpoint}service/data"
+        return f"{self.endpoint}service/upload_data"
+
+    @property
+    def get_load_state_url(self):
+        return f"{self.endpoint}service/get_upload_state"
 
     def query(self, text: str, session_id: str = None):
         q = dict(question=text, session_id=session_id)
@@ -88,7 +92,14 @@ def add_knowledge(self, file_dir: str, enable_qa_extraction: bool):
         q = dict(file_path=file_dir, enable_qa_extraction=enable_qa_extraction)
         r = requests.post(self.load_data_url, json=q)
         r.raise_for_status()
-        return
+        response = dotdict(json.loads(r.text))
+        return response
+
+    def get_knowledge_state(self, task_id: str):
+        r = requests.get(self.get_load_state_url, params={"task_id": task_id})
+        r.raise_for_status()
+        response = dotdict(json.loads(r.text))
+        return response
 
     def reload_config(self, config: Any):
         global cache_config
diff --git a/src/pai_rag/app/web/tabs/upload_tab.py b/src/pai_rag/app/web/tabs/upload_tab.py
index 8dba2c6b..96e4d7a2 100644
--- a/src/pai_rag/app/web/tabs/upload_tab.py
+++ b/src/pai_rag/app/web/tabs/upload_tab.py
@@ -1,8 +1,11 @@
 import os
 from typing import Dict, Any
 import gradio as gr
+import time
 from pai_rag.app.web.rag_client import rag_client
 from pai_rag.app.web.view_model import view_model
+from pai_rag.utils.file_utils import MyUploadFile
+import pandas as pd
 
 
 def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
@@ -14,14 +17,36 @@ def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extracti
     if not upload_files:
         return "No file selected. Please choose at least one file."
 
+    my_upload_files = []
     for file in upload_files:
         file_dir = os.path.dirname(file.name)
-        rag_client.add_knowledge(file_dir, enable_qa_extraction)
-    return (
-        "Upload "
-        + str(len(upload_files))
-        + " files Success! \n \n Relevant content has been added to the vector store, you can now start chatting and asking questions."
-    )
+        response = rag_client.add_knowledge(file_dir, enable_qa_extraction)
+        my_upload_files.append(
+            MyUploadFile(os.path.basename(file.name), response["task_id"])
+        )
+
+    result = {"Info": ["StartTime", "EndTime", "Duration(s)", "Status"]}
+    while not all(file.finished is True for file in my_upload_files):
+        for file in my_upload_files:
+            response = rag_client.get_knowledge_state(str(file.task_id))
+            file.update_state(response["status"])
+            file.update_process_duration()
+            result[file.file_name] = file.__info__()
+            if response["status"] in ["completed", "failed"]:
+                file.is_finished()
+        yield [
+            gr.update(visible=True, value=pd.DataFrame(result)),
+            gr.update(visible=False),
+        ]
+        time.sleep(2)
+
+    yield [
+        gr.update(visible=True, value=pd.DataFrame(result)),
+        gr.update(
+            visible=True,
+            value="Uploaded all files successfully! \n Relevant content has been added to the vector store; you can now start chatting and asking questions.",
+        ),
+    ]
 
 
 def create_upload_tab() -> Dict[str, Any]:
@@ -47,14 +72,20 @@ def create_upload_tab() -> Dict[str, Any]:
                     label="Upload a knowledge file.", file_count="multiple"
                 )
                 upload_file_btn = gr.Button("Upload", variant="primary")
-                upload_file_state = gr.Textbox(label="Upload State")
+                upload_file_state_df = gr.DataFrame(
+                    label="Upload Status Info", visible=False
+                )
+                upload_file_state = gr.Textbox(label="Upload Status", visible=False)
             with gr.Tab("Directory"):
                 upload_file_dir = gr.File(
                     label="Upload a knowledge directory.",
                     file_count="directory",
                 )
                 upload_dir_btn = gr.Button("Upload", variant="primary")
-                upload_dir_state = gr.Textbox(label="Upload State")
+                upload_dir_state_df = gr.DataFrame(
+                    label="Upload Status Info", visible=False
+                )
+                upload_dir_state = gr.Textbox(label="Upload Status", visible=False)
         upload_file_btn.click(
             fn=upload_knowledge,
             inputs=[
@@ -63,7 +94,7 @@ def create_upload_tab() -> Dict[str, Any]:
                 chunk_overlap,
                 enable_qa_extraction,
             ],
-            outputs=upload_file_state,
+            outputs=[upload_file_state_df, upload_file_state],
             api_name="upload_knowledge",
         )
         upload_dir_btn.click(
@@ -74,7 +105,7 @@ def create_upload_tab() -> Dict[str, Any]:
                 chunk_overlap,
                 enable_qa_extraction,
             ],
-            outputs=upload_dir_state,
+            outputs=[upload_dir_state_df, upload_dir_state],
            api_name="upload_knowledge_dir",
        )
        return {
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 8a842ca6..20955fab 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -58,8 +58,8 @@ def reload(self, config):
         self.logger.info("RagApplication reloaded successfully.")
 
     # TODO: implement asynchronous ingestion for bulk file uploads
-    async def load_knowledge(self, file_dir, enable_qa_extraction=False):
-        await self.data_loader.load(file_dir, enable_qa_extraction)
+    def load_knowledge(self, file_dir, enable_qa_extraction=False):
+        self.data_loader.load(file_dir, enable_qa_extraction)
 
     async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
         if not query.question:
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index a7f0fc7d..43e7819c 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -11,7 +11,7 @@
 )
 from pai_rag.app.web.view_model import view_model
 from openinference.instrumentation import using_attributes
-from typing import Any
+from typing import Any, Dict
 
 
 def trace_correlation_id(function):
@@ -41,14 +41,25 @@ def initialize(self, config_file: str):
         view_model.sync_app_config(self.rag_configuration.get_value())
         self.rag = RagApplication()
         self.rag.initialize(self.rag_configuration.get_value())
+        self.tasks_status: Dict[str, str] = {}
 
     def reload(self, new_config: Any):
         self.rag_configuration.update(new_config)
         self.rag.reload(self.rag_configuration.get_value())
         self.rag_configuration.persist()
 
-    async def add_knowledge(self, file_dir: str, enable_qa_extraction: bool = False):
-        await self.rag.load_knowledge(file_dir, enable_qa_extraction)
+    def add_knowledge_async(
+        self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
+    ):
+        self.tasks_status[task_id] = "processing"
+        try:
+            self.rag.load_knowledge(file_dir, enable_qa_extraction)
+            self.tasks_status[task_id] = "completed"
+        except Exception:
+            self.tasks_status[task_id] = "failed"
+
+    def get_task_status(self, task_id: str) -> str:
+        return self.tasks_status.get(task_id, "unknown")
 
     async def aquery(self, query: RagQuery) -> RagResponse:
         return await self.rag.aquery(query)
diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py
index a4d3c5a1..2f2f96f5 100644
--- a/src/pai_rag/data/rag_dataloader.py
+++ b/src/pai_rag/data/rag_dataloader.py
@@ -1,5 +1,7 @@
 import os
 from typing import Any, Dict
+import asyncio
+import nest_asyncio
 from llama_index.core import Settings
 from llama_index.core.schema import TextNode
 from llama_index.llms.huggingface import HuggingFaceLLM
@@ -57,7 +59,7 @@ def _extract_file_type(self, metadata: Dict[str, Any]):
         file_name = metadata.get("file_name", "dummy.txt")
         return os.path.splitext(file_name)[1]
 
-    async def load(self, file_directory: str, enable_qa_extraction: bool):
+    async def aload(self, file_directory: str, enable_qa_extraction: bool):
         data_reader = self.datareader_factory.get_reader(file_directory)
         docs = data_reader.load_data()
         logger.info(f"[DataReader] Loaded {len(docs)} docs.")
@@ -113,3 +115,10 @@
         self.index.storage_context.persist(persist_dir=store_path.persist_path)
         logger.info(f"Inserted {len(nodes)} nodes successfully.")
         return
+
+    nest_asyncio.apply()  # patch the event loop so it can be re-entered (nested loops)
+
+    def load(self, file_directory: str, enable_qa_extraction: bool):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.aload(file_directory, enable_qa_extraction))
+        return
diff --git a/src/pai_rag/modules/embedding/embedding.py b/src/pai_rag/modules/embedding/embedding.py
index 8b8350aa..75efbb93 100644
--- a/src/pai_rag/modules/embedding/embedding.py
+++ b/src/pai_rag/modules/embedding/embedding.py
@@ -53,7 +53,8 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
             model_path = os.path.join(model_dir, model_name)
 
             embed_model = HuggingFaceEmbedding(
-                model_name=model_path, embed_batch_size=embed_batch_size
+                model_name=model_path,
+                embed_batch_size=embed_batch_size,
             )
 
             logger.info(
diff --git a/tests/core/test_rag_application.py b/tests/core/test_rag_application.py
index 0e42a291..457d599c 100644
--- a/tests/core/test_rag_application.py
+++ b/tests/core/test_rag_application.py
@@ -28,10 +28,10 @@ def rag_app():
 
 
 # Test load knowledge file
-async def test_add_knowledge_file(rag_app: RagApplication):
+def test_add_knowledge_file(rag_app: RagApplication):
     data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
     print(len(rag_app.index.docstore.docs))
-    await rag_app.load_knowledge(data_dir)
+    rag_app.load_knowledge(data_dir)
     print(len(rag_app.index.docstore.docs))
     assert len(rag_app.index.docstore.docs) > 0
 
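
Note: the sketch below shows how a client outside the Gradio UI could exercise the new asynchronous upload flow end to end. It is a minimal illustration only: the endpoint paths (service/upload_data, service/get_upload_state), the request payload fields, and the status values ("processing", "completed", "failed", "unknown") are taken from the diff above, while the base URL, port, and polling interval are assumptions about a local deployment.

# Minimal polling client for the new upload endpoints (sketch, not part of the patch).
import time
import requests

BASE_URL = "http://127.0.0.1:8000/"  # assumed local deployment; adjust to your service

def upload_and_wait(file_dir: str, enable_qa_extraction: bool = False,
                    poll_interval: float = 2.0) -> str:
    # Kick off ingestion; the server now returns a task_id immediately instead of
    # holding the HTTP connection open for the whole (possibly long-running) job.
    r = requests.post(
        f"{BASE_URL}service/upload_data",
        json={"file_path": file_dir, "enable_qa_extraction": enable_qa_extraction},
    )
    r.raise_for_status()
    task_id = r.json()["task_id"]

    # Poll get_upload_state until the background task reaches a terminal state.
    while True:
        s = requests.get(
            f"{BASE_URL}service/get_upload_state", params={"task_id": task_id}
        )
        s.raise_for_status()
        status = s.json()["status"]
        if status in ("completed", "failed"):
            return status
        time.sleep(poll_interval)

if __name__ == "__main__":
    print(upload_and_wait("./docs"))

This mirrors what upload_knowledge in upload_tab.py does through rag_client, minus the DataFrame rendering.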
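
Note: upload_tab.py imports MyUploadFile from pai_rag.utils.file_utils, but that helper is not included in this diff. The sketch below is a hypothetical reconstruction inferred only from the call sites above (the finished flag, update_state, update_process_duration, is_finished, and __info__ returning one value per "Info" row); the real implementation may differ.

# Hypothetical shape of the MyUploadFile helper (assumption; not part of the patch).
from datetime import datetime

class MyUploadFile:
    def __init__(self, file_name: str, task_id: str):
        self.file_name = file_name
        self.task_id = task_id
        self.start_time = datetime.now()
        self.end_time = None
        self.duration = None
        self.state = "processing"
        self.finished = False

    def update_state(self, state: str):
        # Latest status string reported by service/get_upload_state.
        self.state = state

    def update_process_duration(self):
        # Refresh elapsed time until the task reaches a terminal state.
        if not self.finished:
            self.end_time = datetime.now()
            self.duration = round((self.end_time - self.start_time).total_seconds(), 2)

    def is_finished(self):
        self.finished = True

    def __info__(self):
        # One DataFrame column per file: StartTime, EndTime, Duration(s), Status.
        return [
            self.start_time.strftime("%Y-%m-%d %H:%M:%S"),
            self.end_time.strftime("%Y-%m-%d %H:%M:%S") if self.end_time else "",
            self.duration,
            self.state,
        ]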