Skip to content

Commit

Permalink
Bugfix: connection error for long-running upload tasks (#62)
Browse files Browse the repository at this point in the history
* Fix connection error for long-running upload jobs

* fix testcase bugs

* support num workers for embedding model

* Refactor query api and add dataframe UI

* Refactor query api

* Remove embedding workers
  • Loading branch information
wwxxzz authored Jun 13, 2024
1 parent b492801 commit ef7090b
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 27 deletions.
23 changes: 17 additions & 6 deletions src/pai_rag/app/api/query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any
from fastapi import APIRouter, Body
from fastapi import APIRouter, Body, BackgroundTasks
import uuid
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api.models import (
RagQuery,
Expand Down Expand Up @@ -39,12 +40,22 @@ async def aupdate(new_config: Any = Body(None)):
return {"msg": "Update RAG configuration successfully."}


@router.post("/data")
async def load_data(input: DataInput):
await rag_service.add_knowledge(
file_dir=input.file_path, enable_qa_extraction=input.enable_qa_extraction
@router.post("/upload_data")
async def load_data(input: DataInput, background_tasks: BackgroundTasks):
    """Kick off knowledge ingestion in the background.

    Returns immediately with a task id the client can poll via
    ``/get_upload_state``, so long-running uploads don't hold the
    HTTP connection open.
    """
    job_id = uuid.uuid4().hex
    background_tasks.add_task(
        rag_service.add_knowledge_async,
        task_id=job_id,
        file_dir=input.file_path,
        enable_qa_extraction=input.enable_qa_extraction,
    )
    return {"task_id": job_id}


@router.get("/get_upload_state")
def task_status(task_id: str):
    """Report the current state of a background upload task."""
    return {"task_id": task_id, "status": rag_service.get_task_status(task_id)}


@router.post("/evaluate/response")
Expand Down
15 changes: 13 additions & 2 deletions src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def config_url(self):

@property
def load_data_url(self):
    """URL of the background-upload endpoint."""
    return self.endpoint + "service/upload_data"

@property
def get_load_state_url(self):
    """URL of the upload-status polling endpoint."""
    return self.endpoint + "service/get_upload_state"

def query(self, text: str, session_id: str = None):
q = dict(question=text, session_id=session_id)
def add_knowledge(self, file_dir: str, enable_qa_extraction: bool):
    """POST a directory of files to the upload endpoint.

    Returns the parsed JSON response (contains the background task id).
    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    payload = {
        "file_path": file_dir,
        "enable_qa_extraction": enable_qa_extraction,
    }
    resp = requests.post(self.load_data_url, json=payload)
    resp.raise_for_status()
    return dotdict(json.loads(resp.text))

def get_knowledge_state(self, task_id: str):
    """Poll the service for the status of an upload task.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    resp = requests.get(self.get_load_state_url, params={"task_id": task_id})
    resp.raise_for_status()
    return dotdict(json.loads(resp.text))

def reload_config(self, config: Any):
global cache_config
Expand Down
51 changes: 41 additions & 10 deletions src/pai_rag/app/web/tabs/upload_tab.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from typing import Dict, Any
import gradio as gr
import time
from pai_rag.app.web.rag_client import rag_client
from pai_rag.app.web.view_model import view_model
from pai_rag.utils.file_utils import MyUploadFile
import pandas as pd


def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extraction):
Expand All @@ -14,14 +17,36 @@ def upload_knowledge(upload_files, chunk_size, chunk_overlap, enable_qa_extracti
if not upload_files:
return "No file selected. Please choose at least one file."

my_upload_files = []
for file in upload_files:
file_dir = os.path.dirname(file.name)
rag_client.add_knowledge(file_dir, enable_qa_extraction)
return (
"Upload "
+ str(len(upload_files))
+ " files Success! \n \n Relevant content has been added to the vector store, you can now start chatting and asking questions."
)
response = rag_client.add_knowledge(file_dir, enable_qa_extraction)
my_upload_files.append(
MyUploadFile(os.path.basename(file.name), response["task_id"])
)

result = {"Info": ["StartTime", "EndTime", "Duration(s)", "Status"]}
while not all(file.finished is True for file in my_upload_files):
for file in my_upload_files:
response = rag_client.get_knowledge_state(str(file.task_id))
file.update_state(response["status"])
file.update_process_duration()
result[file.file_name] = file.__info__()
if response["status"] in ["completed", "failed"]:
file.is_finished()
yield [
gr.update(visible=True, value=pd.DataFrame(result)),
gr.update(visible=False),
]
time.sleep(2)

yield [
gr.update(visible=True, value=pd.DataFrame(result)),
gr.update(
visible=True,
value="Uploaded all files successfully! \n Relevant content has been added to the vector store, you can now start chatting and asking questions.",
),
]


def create_upload_tab() -> Dict[str, Any]:
Expand All @@ -47,14 +72,20 @@ def create_upload_tab() -> Dict[str, Any]:
label="Upload a knowledge file.", file_count="multiple"
)
upload_file_btn = gr.Button("Upload", variant="primary")
upload_file_state = gr.Textbox(label="Upload State")
upload_file_state_df = gr.DataFrame(
label="Upload Status Info", visible=False
)
upload_file_state = gr.Textbox(label="Upload Status", visible=False)
with gr.Tab("Directory"):
upload_file_dir = gr.File(
label="Upload a knowledge directory.",
file_count="directory",
)
upload_dir_btn = gr.Button("Upload", variant="primary")
upload_dir_state = gr.Textbox(label="Upload State")
upload_dir_state_df = gr.DataFrame(
label="Upload Status Info", visible=False
)
upload_dir_state = gr.Textbox(label="Upload Status", visible=False)
upload_file_btn.click(
fn=upload_knowledge,
inputs=[
Expand All @@ -63,7 +94,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_qa_extraction,
],
outputs=upload_file_state,
outputs=[upload_file_state_df, upload_file_state],
api_name="upload_knowledge",
)
upload_dir_btn.click(
Expand All @@ -74,7 +105,7 @@ def create_upload_tab() -> Dict[str, Any]:
chunk_overlap,
enable_qa_extraction,
],
outputs=upload_dir_state,
outputs=[upload_dir_state_df, upload_dir_state],
api_name="upload_knowledge_dir",
)
return {
Expand Down
4 changes: 2 additions & 2 deletions src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def reload(self, config):
self.logger.info("RagApplication reloaded successfully.")

# TODO: implement asynchronous ingestion for bulk file uploads
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
await self.data_loader.load(file_dir, enable_qa_extraction)
def load_knowledge(self, file_dir, enable_qa_extraction=False):
    """Synchronously ingest every document under `file_dir` into the index."""
    self.data_loader.load(file_dir, enable_qa_extraction)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
Expand Down
17 changes: 14 additions & 3 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from pai_rag.app.web.view_model import view_model
from openinference.instrumentation import using_attributes
from typing import Any
from typing import Any, Dict


def trace_correlation_id(function):
Expand Down Expand Up @@ -41,14 +41,25 @@ def initialize(self, config_file: str):
view_model.sync_app_config(self.rag_configuration.get_value())
self.rag = RagApplication()
self.rag.initialize(self.rag_configuration.get_value())
self.tasks_status: Dict[str, str] = {}

def reload(self, new_config: Any):
    """Apply `new_config` to the running RAG app and persist it to disk."""
    self.rag_configuration.update(new_config)
    updated_value = self.rag_configuration.get_value()
    self.rag.reload(updated_value)
    self.rag_configuration.persist()

async def add_knowledge(self, file_dir: str, enable_qa_extraction: bool = False):
await self.rag.load_knowledge(file_dir, enable_qa_extraction)
def add_knowledge_async(
    self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
):
    """Ingest files from `file_dir`, tracking progress in `self.tasks_status`.

    Runs as a FastAPI background task: status for `task_id` moves from
    "processing" to "completed" or "failed", which clients poll via
    `get_task_status`. Never raises — a crash here would be lost by the
    background-task runner anyway.
    """
    self.tasks_status[task_id] = "processing"
    try:
        self.rag.load_knowledge(file_dir, enable_qa_extraction)
    except Exception:
        # Don't silently swallow the failure: keep the "failed" status
        # contract for pollers, but record the traceback for operators.
        import logging

        logging.getLogger(__name__).exception(
            "Knowledge upload task %s failed", task_id
        )
        self.tasks_status[task_id] = "failed"
    else:
        self.tasks_status[task_id] = "completed"

def get_task_status(self, task_id: str) -> str:
    """Return the recorded status for `task_id`, or "unknown" if never seen."""
    try:
        return self.tasks_status[task_id]
    except KeyError:
        return "unknown"

async def aquery(self, query: RagQuery) -> RagResponse:
return await self.rag.aquery(query)
Expand Down
11 changes: 10 additions & 1 deletion src/pai_rag/data/rag_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from typing import Any, Dict
import asyncio
import nest_asyncio
from llama_index.core import Settings
from llama_index.core.schema import TextNode
from llama_index.llms.huggingface import HuggingFaceLLM
Expand Down Expand Up @@ -57,7 +59,7 @@ def _extract_file_type(self, metadata: Dict[str, Any]):
file_name = metadata.get("file_name", "dummy.txt")
return os.path.splitext(file_name)[1]

async def load(self, file_directory: str, enable_qa_extraction: bool):
async def aload(self, file_directory: str, enable_qa_extraction: bool):
data_reader = self.datareader_factory.get_reader(file_directory)
docs = data_reader.load_data()
logger.info(f"[DataReader] Loaded {len(docs)} docs.")
Expand Down Expand Up @@ -113,3 +115,10 @@ async def load(self, file_directory: str, enable_qa_extraction: bool):
self.index.storage_context.persist(persist_dir=store_path.persist_path)
logger.info(f"Inserted {len(nodes)} nodes successfully.")
return

nest_asyncio.apply()  # patch the event loop so run_until_complete can be nested inside an already-running loop

def load(self, file_directory: str, enable_qa_extraction: bool):
    """Synchronous wrapper around `aload` for callers without an event loop.

    Relies on the module-level `nest_asyncio.apply()` so the coroutine can
    run even when a loop is already running in this thread.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Python 3.10+ raises when no current event loop exists in this
        # thread (e.g. a worker thread); create and install one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    loop.run_until_complete(self.aload(file_directory, enable_qa_extraction))
3 changes: 2 additions & 1 deletion src/pai_rag/modules/embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def _create_new_instance(self, new_params: Dict[str, Any]):

model_path = os.path.join(model_dir, model_name)
embed_model = HuggingFaceEmbedding(
model_name=model_path, embed_batch_size=embed_batch_size
model_name=model_path,
embed_batch_size=embed_batch_size,
)

logger.info(
Expand Down
4 changes: 2 additions & 2 deletions tests/core/test_rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def rag_app():


# Test load knowledge file
def test_add_knowledge_file(rag_app: RagApplication):
    """Loading the sample corpus should populate the docstore."""
    corpus_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
    docs_before = len(rag_app.index.docstore.docs)
    print(docs_before)
    rag_app.load_knowledge(corpus_dir)
    docs_after = len(rag_app.index.docstore.docs)
    print(docs_after)
    assert docs_after > 0

Expand Down

0 comments on commit ef7090b

Please sign in to comment.