Remove local storage and enable Elasticsearch hybrid query mode #60

Merged: 21 commits, Jun 13, 2024

10 changes: 6 additions & 4 deletions src/pai_rag/app/api/models.py
@@ -2,13 +2,16 @@
from typing import List, Dict


class VectorDbConfig(BaseModel):
faiss_path: str | None = None


class RagQuery(BaseModel):
question: str
temperature: float | None = 0.1
vector_topk: int | None = 3
score_threshold: float | None = 0.5
chat_history: List[Dict[str, str]] | None = None
session_id: str | None = None
vector_db: VectorDbConfig | None = None


class LlmQuery(BaseModel):
@@ -20,8 +23,7 @@ class LlmQuery(BaseModel):

class RetrievalQuery(BaseModel):
question: str
topk: int | None = 3
score_threshold: float | None = 0.5
vector_db: VectorDbConfig | None = None


class RagResponse(BaseModel):
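
With these model changes, a caller can point an individual query at a specific FAISS store by attaching a VectorDbConfig to the request. A minimal sketch of building such a request; the question text and path below are made-up examples, not project defaults:

from pai_rag.app.api.models import RagQuery, VectorDbConfig

# Hypothetical per-request override: route this query to a local FAISS index
# instead of whatever index the server was configured with at startup.
query = RagQuery(
    question="What does hybrid query mode add on top of vector retrieval?",
    vector_db=VectorDbConfig(faiss_path="/tmp/demo_faiss_store"),  # made-up path
)
print(query)
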
76 changes: 43 additions & 33 deletions src/pai_rag/core/rag_application.py
@@ -1,5 +1,3 @@
from pai_rag.data.rag_dataloader import RagDataLoader
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.module_registry import module_registry
from pai_rag.evaluations.batch_evaluator import BatchEvaluator
from pai_rag.app.api.models import (
@@ -24,49 +22,34 @@ def uuid_generator() -> str:
class RagApplication:
def __init__(self):
self.name = "RagApplication"
logging.basicConfig(level=logging.INFO)  # set the log level to INFO
self.logger = logging.getLogger(__name__)

def initialize(self, config):
self.config = config

module_registry.init_modules(self.config)
self.index = module_registry.get_module("IndexModule")
self.llm = module_registry.get_module("LlmModule")
self.retriever = module_registry.get_module("RetrieverModule")
self.chat_store = module_registry.get_module("ChatStoreModule")
self.query_engine = module_registry.get_module("QueryEngineModule")
self.chat_engine_factory = module_registry.get_module("ChatEngineFactoryModule")
self.llm_chat_engine_factory = module_registry.get_module(
"LlmChatEngineFactoryModule"
)
self.data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
self.agent = module_registry.get_module("AgentModule")

oss_cache = None
if config.get("oss_cache", None):
oss_cache = OssCache(config.oss_cache)
node_parser = module_registry.get_module("NodeParserModule")

self.data_loader = RagDataLoader(
self.data_reader_factory, node_parser, self.index, oss_cache
)
self.logger.info("RagApplication initialized successfully.")

def reload(self, config):
self.initialize(config)
self.logger.info("RagApplication reloaded successfully.")

# TODO: make ingestion asynchronous for large batches of uploaded files
def load_knowledge(self, file_dir, enable_qa_extraction=False):
self.data_loader.load(file_dir, enable_qa_extraction)
async def load_knowledge(self, file_dir, enable_qa_extraction=False):
data_loader = module_registry.get_module_with_config(
"DataLoaderModule", self.config
)
await data_loader.aload(file_dir, enable_qa_extraction)

async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
if not query.question:
return RetrievalResponse(docs=[])

query_bundle = QueryBundle(query.question)
node_results = await self.query_engine.aretrieve(query_bundle)

query_engine = module_registry.get_module_with_config(
"QueryEngineModule", self.config
)
node_results = await query_engine.aretrieve(query_bundle)

docs = [
ContextDoc(
@@ -96,11 +79,24 @@ async def aquery(self, query: RagQuery) -> RagResponse:
answer="Empty query. Please input your question.", session_id=session_id
)

query_chat_engine = self.chat_engine_factory.get_chat_engine(
sessioned_config = self.config
if query.vector_db and query.vector_db.faiss_path:
sessioned_config = self.config.copy()
sessioned_config.index.update({"persist_path": query.vector_db.faiss_path})
print(sessioned_config)

chat_engine_factory = module_registry.get_module_with_config(
"ChatEngineFactoryModule", sessioned_config
)
query_chat_engine = chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await query_chat_engine.achat(query.question)
self.chat_store.persist()

chat_store = module_registry.get_module_with_config(
"ChatStoreModule", sessioned_config
)
chat_store.persist()
return RagResponse(answer=response.response, session_id=session_id)

async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
@@ -122,11 +118,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
answer="Empty query. Please input your question.", session_id=session_id
)

llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
llm_chat_engine_factory = module_registry.get_module_with_config(
"LlmChatEngineFactoryModule", self.config
)
llm_chat_engine = llm_chat_engine_factory.get_chat_engine(
session_id, query.chat_history
)
response = await llm_chat_engine.achat(query.question)
self.chat_store.persist()

chat_store = module_registry.get_module_with_config(
"ChatStoreModule", self.config
)
chat_store.persist()
return LlmResponse(answer=response.response, session_id=session_id)

async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
@@ -143,11 +146,18 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
if not query.question:
return LlmResponse(answer="Empty query. Please input your question.")

response = await self.agent.achat(query.question)
agent = module_registry.get_module_with_config("AgentModule", self.config)
response = await agent.achat(query.question)
return LlmResponse(answer=response.response)

async def batch_evaluate_retrieval_and_response(self, type):
batch_eval = BatchEvaluator(self.config, self.retriever, self.query_engine)
retriever = module_registry.get_module_with_config(
"RetrieverModule", self.config
)
query_engine = module_registry.get_module_with_config(
"QueryEngineModule", self.config
)
batch_eval = BatchEvaluator(self.config, retriever, query_engine)
df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation(
type=type, workers=2, save_to_file=True
)
10 changes: 7 additions & 3 deletions src/pai_rag/core/rag_service.py
@@ -12,6 +12,9 @@
from pai_rag.app.web.view_model import view_model
from openinference.instrumentation import using_attributes
from typing import Any, Dict
import logging

logger = logging.getLogger(__name__)


def trace_correlation_id(function):
@@ -48,14 +51,15 @@ def reload(self, new_config: Any):
self.rag.reload(self.rag_configuration.get_value())
self.rag_configuration.persist()

def add_knowledge_async(
async def add_knowledge_async(
self, task_id: str, file_dir: str, enable_qa_extraction: bool = False
):
self.tasks_status[task_id] = "processing"
try:
self.rag.load_knowledge(file_dir, enable_qa_extraction)
await self.rag.load_knowledge(file_dir, enable_qa_extraction)
self.tasks_status[task_id] = "completed"
except Exception:
except Exception as ex:
logger.error(f"Upload failed: {ex}")
self.tasks_status[task_id] = "failed"

def get_task_status(self, task_id: str) -> str:
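
Because add_knowledge_async is now a coroutine, callers must await it (or schedule it on the running event loop) instead of invoking it synchronously. A hedged sketch of one possible call site; the rag_service argument, folder path, and task id are assumptions for illustration:

async def upload_folder(rag_service, folder: str = "./upload_data"):
    # Assumed call pattern: the coroutine marks the task "processing", ingests
    # the folder, then flips the status to "completed" or "failed".
    task_id = "upload-123"  # hypothetical id
    await rag_service.add_knowledge_async(task_id, folder, enable_qa_extraction=False)
    print(rag_service.get_task_status(task_id))  # "completed" on success, "failed" on error
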
4 changes: 2 additions & 2 deletions src/pai_rag/data/rag_dataloader.py
@@ -37,8 +37,8 @@ def __init__(
):
self.datareader_factory = datareader_factory
self.node_parser = node_parser
self.index = index
self.oss_cache = oss_cache
self.index = index

if use_local_qa_model:
# the API does not support this option yet
@@ -111,7 +111,7 @@ async def aload(self, file_directory: str, enable_qa_extraction: bool):

logger.info("[DataReader] Start inserting to index.")

self.index.insert_nodes(nodes)
await self.index.insert_nodes_async(nodes)
self.index.storage_context.persist(persist_dir=store_path.persist_path)
logger.info(f"Inserted {len(nodes)} nodes successfully.")
return
15 changes: 2 additions & 13 deletions src/pai_rag/data/rag_datapipeline.py
@@ -2,14 +2,12 @@
import click
import os
from pathlib import Path
from pai_rag.data.rag_dataloader import RagDataLoader
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.module_registry import module_registry


class RagDataPipeline:
def __init__(self, data_loader: RagDataLoader):
def __init__(self, data_loader):
self.data_loader = data_loader

async def ingest_from_folder(self, folder_path: str, enable_qa_extraction: bool):
@@ -23,16 +21,7 @@ def __init_data_pipeline(use_local_qa_model):
config = RagConfiguration.from_file(config_file).get_value()
module_registry.init_modules(config)

oss_cache = None
if config.get("oss_cache", None):
oss_cache = OssCache(config.oss_cache)
node_parser = module_registry.get_module("NodeParserModule")
index = module_registry.get_module("IndexModule")
data_reader_factory = module_registry.get_module("DataReaderFactoryModule")

data_loader = RagDataLoader(
data_reader_factory, node_parser, index, oss_cache, use_local_qa_model
)
data_loader = module_registry.get_module_with_config("DataLoaderModule", config)
return RagDataPipeline(data_loader)


4 changes: 2 additions & 2 deletions src/pai_rag/evaluations/batch_evaluator.py
@@ -189,8 +189,8 @@ def __init_evaluator_pipeline():
config = RagConfiguration.from_file(config_file).get_value()
module_registry.init_modules(config)

retriever = module_registry.get_module("RetrieverModule")
query_engine = module_registry.get_module("QueryEngineModule")
retriever = module_registry.get_module_with_config("RetrieverModule", config)
query_engine = module_registry.get_module_with_config("QueryEngineModule", config)

return BatchEvaluator(config, retriever, query_engine)

23 changes: 20 additions & 3 deletions src/pai_rag/evaluations/dataset_generation/generate_dataset.py
@@ -1,5 +1,6 @@
import logging
import os
from pathlib import Path
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.modules.module_registry import module_registry
from llama_index.core.prompts.prompt_type import PromptType
@@ -16,8 +17,13 @@
DEFAULT_TEXT_QA_PROMPT_TMPL,
DEFAULT_QUESTION_GENERATION_QUERY,
)

import json

_BASE_DIR = Path(__file__).parent.parent.parent
DEFAULT_EVAL_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml")
DEFAULT_EVAL_DATA_FOLDER = "tests/testdata/paul_graham"


class GenerateDatasetPipeline(ModifiedRagDatasetGenerator):
def __init__(
@@ -29,11 +35,22 @@ def __init__(
show_progress: Optional[bool] = True,
) -> None:
self.name = "GenerateDatasetPipeline"
self.nodes = list(
module_registry.get_module("IndexModule").docstore.docs.values()
self.config = RagConfiguration.from_file(DEFAULT_EVAL_CONFIG_FILE).get_value()

# load nodes
module_registry.init_modules(self.config)
datareader_factory = module_registry.get_module_with_config(
"DataReaderFactoryModule", self.config
)
self.node_parser = module_registry.get_module_with_config(
"NodeParserModule", self.config
)
reader = datareader_factory.get_reader(DEFAULT_EVAL_DATA_FOLDER)
docs = reader.load_data()
self.nodes = self.node_parser.get_nodes_from_documents(docs)

self.num_questions_per_chunk = num_questions_per_chunk
self.llm = module_registry.get_module("LlmModule")
self.llm = module_registry.get_module_with_config("LlmModule", self.config)
self.text_question_template = PromptTemplate(text_question_template_str)
self.text_qa_template = PromptTemplate(
text_qa_template_str, prompt_type=PromptType.QUESTION_ANSWER
5 changes: 5 additions & 0 deletions src/pai_rag/modules/__init__.py
@@ -1,5 +1,6 @@
from pai_rag.modules.embedding.embedding import EmbeddingModule
from pai_rag.modules.llm.llm_module import LlmModule
from pai_rag.modules.datareader.data_loader import DataLoaderModule
from pai_rag.modules.datareader.datareader_factory import DataReaderFactoryModule
from pai_rag.modules.index.index import IndexModule
from pai_rag.modules.nodeparser.node_parser import NodeParserModule
@@ -12,10 +13,13 @@
from pai_rag.modules.chat.chat_store import ChatStoreModule
from pai_rag.modules.agent.agent import AgentModule
from pai_rag.modules.tool.tool import ToolModule
from pai_rag.modules.cache.oss_cache import OssCacheModule


ALL_MODULES = [
"EmbeddingModule",
"LlmModule",
"DataLoaderModule",
"DataReaderFactoryModule",
"IndexModule",
"NodeParserModule",
@@ -28,6 +32,7 @@
"LlmChatEngineFactoryModule",
"AgentModule",
"ToolModule",
"OssCacheModule",
]

__all__ = ALL_MODULES + ["ALL_MODULES"]
26 changes: 5 additions & 21 deletions src/pai_rag/modules/base/configurable_module.py
@@ -1,19 +1,19 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any
import logging

DEFAULT_INSTANCE_KEY = "__DEFAULT_INSTANCE__"


logger = logging.getLogger(__name__)


class ConfigurableModule(ABC):
"""Configurable Module

Helps to create instances according to configuration.
"""

def __init__(self):
self.__params_map = {}
self.__instance_map = {}

@abstractmethod
def _create_new_instance(self, new_params: Dict[str, Any]):
raise NotImplementedError
@@ -24,20 +24,4 @@ def get_dependencies() -> List[str]:
raise NotImplementedError

def get_or_create(self, new_params: Dict[str, Any]):
return self.get_or_create_by_name(new_params=new_params)

def get_or_create_by_name(
self, new_params: Dict[str, Any], name: str = DEFAULT_INSTANCE_KEY
):
# Create new instance when initializing or config changed.
if (
self.__params_map.get(name, None) is None
or self.__params_map[name] != new_params
):
print(f"{self.__class__.__name__} param changed, updating")
self.__instance_map[name] = self._create_new_instance(new_params)
self.__params_map[name] = new_params
else:
print(f"{self.__class__.__name__} param unchanged, skipping")

return self.__instance_map[name]
return self._create_new_instance(new_params)
20 changes: 20 additions & 0 deletions src/pai_rag/modules/cache/oss_cache.py
@@ -0,0 +1,20 @@
from typing import Any, Dict, List
from pai_rag.utils.oss_cache import OssCache
from pai_rag.modules.base.configurable_module import ConfigurableModule
from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
import logging

logger = logging.getLogger(__name__)


class OssCacheModule(ConfigurableModule):
@staticmethod
def get_dependencies() -> List[str]:
return []

def _create_new_instance(self, new_params: Dict[str, Any]):
cache_config = new_params[MODULE_PARAM_CONFIG]
if cache_config:
return OssCache(cache_config)
else:
return None
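
The new OssCacheModule follows the ConfigurableModule contract above: it declares no dependencies, and _create_new_instance returns an OssCache only when an oss_cache section exists in the configuration. A hedged usage sketch; the settings path is illustrative, and the config object is assumed to be loaded the same way as in the other call sites in this PR:

from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.modules.module_registry import module_registry

# Illustrative settings path; substitute the config file your deployment uses.
config = RagConfiguration.from_file("src/pai_rag/config/settings.toml").get_value()
module_registry.init_modules(config)

# Returns an OssCache when an oss_cache section is configured, otherwise None;
# this replaces the guard that previously lived inline in RagApplication.initialize.
oss_cache = module_registry.get_module_with_config("OssCacheModule", config)
if oss_cache is None:
    print("OSS cache disabled.")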