diff --git a/src/pai_rag/app/api/models.py b/src/pai_rag/app/api/models.py index 8cfede1e..62faf36e 100644 --- a/src/pai_rag/app/api/models.py +++ b/src/pai_rag/app/api/models.py @@ -2,13 +2,16 @@ from typing import List, Dict +class VectorDbConfig(BaseModel): + faiss_path: str | None = None + + class RagQuery(BaseModel): question: str temperature: float | None = 0.1 - vector_topk: int | None = 3 - score_threshold: float | None = 0.5 chat_history: List[Dict[str, str]] | None = None session_id: str | None = None + vector_db: VectorDbConfig | None = None class LlmQuery(BaseModel): @@ -20,8 +23,7 @@ class LlmQuery(BaseModel): class RetrievalQuery(BaseModel): question: str - topk: int | None = 3 - score_threshold: float | None = 0.5 + vector_db: VectorDbConfig | None = None class RagResponse(BaseModel): diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index 20955fab..cc314882 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -1,5 +1,3 @@ -from pai_rag.data.rag_dataloader import RagDataLoader -from pai_rag.utils.oss_cache import OssCache from pai_rag.modules.module_registry import module_registry from pai_rag.evaluations.batch_evaluator import BatchEvaluator from pai_rag.app.api.models import ( @@ -24,33 +22,11 @@ def uuid_generator() -> str: class RagApplication: def __init__(self): self.name = "RagApplication" - logging.basicConfig(level=logging.INFO) # set the log level to INFO self.logger = logging.getLogger(__name__) def initialize(self, config): self.config = config - module_registry.init_modules(self.config) - self.index = module_registry.get_module("IndexModule") - self.llm = module_registry.get_module("LlmModule") - self.retriever = module_registry.get_module("RetrieverModule") - self.chat_store = module_registry.get_module("ChatStoreModule") - self.query_engine = module_registry.get_module("QueryEngineModule") - self.chat_engine_factory = module_registry.get_module("ChatEngineFactoryModule") - self.llm_chat_engine_factory = module_registry.get_module( - "LlmChatEngineFactoryModule" - ) - self.data_reader_factory = module_registry.get_module("DataReaderFactoryModule") - self.agent = module_registry.get_module("AgentModule") - - oss_cache = None - if config.get("oss_cache", None): - oss_cache = OssCache(config.oss_cache) - node_parser = module_registry.get_module("NodeParserModule") - - self.data_loader = RagDataLoader( - self.data_reader_factory, node_parser, self.index, oss_cache - ) self.logger.info("RagApplication initialized successfully.") def reload(self, config): @@ -58,15 +34,22 @@ def reload(self, config): self.logger.info("RagApplication reloaded successfully.") # TODO: support async ingestion for large file uploads - def load_knowledge(self, file_dir, enable_qa_extraction=False): - self.data_loader.load(file_dir, enable_qa_extraction) + async def load_knowledge(self, file_dir, enable_qa_extraction=False): + data_loader = module_registry.get_module_with_config( + "DataLoaderModule", self.config + ) + await data_loader.aload(file_dir, enable_qa_extraction) async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse: if not query.question: return RetrievalResponse(docs=[]) query_bundle = QueryBundle(query.question) - node_results = await self.query_engine.aretrieve(query_bundle) + + query_engine = module_registry.get_module_with_config( + "QueryEngineModule", self.config + ) + node_results = await query_engine.aretrieve(query_bundle) docs = [ ContextDoc( @@ -96,11 +79,24 @@ async def aquery(self, query: RagQuery) ->
RagResponse: answer="Empty query. Please input your question.", session_id=session_id ) - query_chat_engine = self.chat_engine_factory.get_chat_engine( + sessioned_config = self.config + if query.vector_db and query.vector_db.faiss_path: + sessioned_config = self.config.copy() + sessioned_config.index.update({"persist_path": query.vector_db.faiss_path}) + self.logger.debug(sessioned_config) + + chat_engine_factory = module_registry.get_module_with_config( + "ChatEngineFactoryModule", sessioned_config + ) + query_chat_engine = chat_engine_factory.get_chat_engine( session_id, query.chat_history ) response = await query_chat_engine.achat(query.question) - self.chat_store.persist() + + chat_store = module_registry.get_module_with_config( + "ChatStoreModule", sessioned_config + ) + chat_store.persist() return RagResponse(answer=response.response, session_id=session_id) async def aquery_llm(self, query: LlmQuery) -> LlmResponse: @@ -122,11 +118,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse: answer="Empty query. Please input your question.", session_id=session_id ) - llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine( + llm_chat_engine_factory = module_registry.get_module_with_config( + "LlmChatEngineFactoryModule", self.config + ) + llm_chat_engine = llm_chat_engine_factory.get_chat_engine( session_id, query.chat_history ) response = await llm_chat_engine.achat(query.question) - self.chat_store.persist() + + chat_store = module_registry.get_module_with_config( + "ChatStoreModule", self.config + ) + chat_store.persist() return LlmResponse(answer=response.response, session_id=session_id) async def aquery_agent(self, query: LlmQuery) -> LlmResponse: @@ -143,11 +146,18 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse: if not query.question: return LlmResponse(answer="Empty query. 
Please input your question.") - response = await self.agent.achat(query.question) + agent = module_registry.get_module_with_config("AgentModule", self.config) + response = await agent.achat(query.question) return LlmResponse(answer=response.response) async def batch_evaluate_retrieval_and_response(self, type): - batch_eval = BatchEvaluator(self.config, self.retriever, self.query_engine) + retriever = module_registry.get_module_with_config( + "RetrieverModule", self.config + ) + query_engine = module_registry.get_module_with_config( + "QueryEngineModule", self.config + ) + batch_eval = BatchEvaluator(self.config, retriever, query_engine) df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation( type=type, workers=2, save_to_file=True ) diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py index 43e7819c..4e768752 100644 --- a/src/pai_rag/core/rag_service.py +++ b/src/pai_rag/core/rag_service.py @@ -12,6 +12,9 @@ from pai_rag.app.web.view_model import view_model from openinference.instrumentation import using_attributes from typing import Any, Dict +import logging + +logger = logging.getLogger(__name__) def trace_correlation_id(function): @@ -48,14 +51,15 @@ def reload(self, new_config: Any): self.rag.reload(self.rag_configuration.get_value()) self.rag_configuration.persist() - def add_knowledge_async( + async def add_knowledge_async( self, task_id: str, file_dir: str, enable_qa_extraction: bool = False ): self.tasks_status[task_id] = "processing" try: - self.rag.load_knowledge(file_dir, enable_qa_extraction) + await self.rag.load_knowledge(file_dir, enable_qa_extraction) self.tasks_status[task_id] = "completed" - except Exception: + except Exception as ex: + logger.error(f"Upload failed: {ex}") self.tasks_status[task_id] = "failed" def get_task_status(self, task_id: str) -> str: diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py index 2f2f96f5..01246ce1 100644 --- a/src/pai_rag/data/rag_dataloader.py +++ b/src/pai_rag/data/rag_dataloader.py @@ -37,8 +37,8 @@ def __init__( ): self.datareader_factory = datareader_factory self.node_parser = node_parser - self.index = index self.oss_cache = oss_cache + self.index = index if use_local_qa_model: # this option is not yet supported by the API @@ -111,7 +111,7 @@ async def aload(self, file_directory: str, enable_qa_extraction: bool): logger.info("[DataReader] Start inserting to index.") - self.index.insert_nodes(nodes) + await self.index.insert_nodes_async(nodes) self.index.storage_context.persist(persist_dir=store_path.persist_path) logger.info(f"Inserted {len(nodes)} nodes successfully.") return diff --git a/src/pai_rag/data/rag_datapipeline.py b/src/pai_rag/data/rag_datapipeline.py index feb0f750..23787e3b 100644 --- a/src/pai_rag/data/rag_datapipeline.py +++ b/src/pai_rag/data/rag_datapipeline.py @@ -2,14 +2,12 @@ import click import os from pathlib import Path -from pai_rag.data.rag_dataloader import RagDataLoader from pai_rag.core.rag_configuration import RagConfiguration -from pai_rag.utils.oss_cache import OssCache from pai_rag.modules.module_registry import module_registry class RagDataPipeline: - def __init__(self, data_loader: RagDataLoader): + def __init__(self, data_loader): self.data_loader = data_loader async def ingest_from_folder(self, folder_path: str, enable_qa_extraction: bool): @@ -23,16 +21,7 @@ def __init_data_pipeline(use_local_qa_model): config = RagConfiguration.from_file(config_file).get_value() module_registry.init_modules(config) - oss_cache = None - if config.get("oss_cache", 
None): - oss_cache = OssCache(config.oss_cache) - node_parser = module_registry.get_module("NodeParserModule") - index = module_registry.get_module("IndexModule") - data_reader_factory = module_registry.get_module("DataReaderFactoryModule") - - data_loader = RagDataLoader( - data_reader_factory, node_parser, index, oss_cache, use_local_qa_model - ) + data_loader = module_registry.get_module_with_config("DataLoaderModule", config) return RagDataPipeline(data_loader) diff --git a/src/pai_rag/evaluations/batch_evaluator.py b/src/pai_rag/evaluations/batch_evaluator.py index 47684e25..abd5f142 100644 --- a/src/pai_rag/evaluations/batch_evaluator.py +++ b/src/pai_rag/evaluations/batch_evaluator.py @@ -189,8 +189,8 @@ def __init_evaluator_pipeline(): config = RagConfiguration.from_file(config_file).get_value() module_registry.init_modules(config) - retriever = module_registry.get_module("RetrieverModule") - query_engine = module_registry.get_module("QueryEngineModule") + retriever = module_registry.get_module_with_config("RetrieverModule", config) + query_engine = module_registry.get_module_with_config("QueryEngineModule", config) return BatchEvaluator(config, retriever, query_engine) diff --git a/src/pai_rag/evaluations/dataset_generation/generate_dataset.py b/src/pai_rag/evaluations/dataset_generation/generate_dataset.py index 3bfcddc0..02b43e60 100644 --- a/src/pai_rag/evaluations/dataset_generation/generate_dataset.py +++ b/src/pai_rag/evaluations/dataset_generation/generate_dataset.py @@ -1,5 +1,6 @@ import logging import os +from pathlib import Path from pai_rag.core.rag_configuration import RagConfiguration from pai_rag.modules.module_registry import module_registry from llama_index.core.prompts.prompt_type import PromptType @@ -16,8 +17,13 @@ DEFAULT_TEXT_QA_PROMPT_TMPL, DEFAULT_QUESTION_GENERATION_QUERY, ) + import json +_BASE_DIR = Path(__file__).parent.parent.parent +DEFAULT_EVAL_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml") +DEFAULT_EVAL_DATA_FOLDER = "tests/testdata/paul_graham" + class GenerateDatasetPipeline(ModifiedRagDatasetGenerator): def __init__( @@ -29,11 +35,22 @@ def __init__( show_progress: Optional[bool] = True, ) -> None: self.name = "GenerateDatasetPipeline" - self.nodes = list( - module_registry.get_module("IndexModule").docstore.docs.values() + self.config = RagConfiguration.from_file(DEFAULT_EVAL_CONFIG_FILE).get_value() + + # load nodes + module_registry.init_modules(self.config) + datareader_factory = module_registry.get_module_with_config( + "DataReaderFactoryModule", self.config ) + self.node_parser = module_registry.get_module_with_config( + "NodeParserModule", self.config + ) + reader = datareader_factory.get_reader(DEFAULT_EVAL_DATA_FOLDER) + docs = reader.load_data() + self.nodes = self.node_parser.get_nodes_from_documents(docs) + self.num_questions_per_chunk = num_questions_per_chunk - self.llm = module_registry.get_module("LlmModule") + self.llm = module_registry.get_module_with_config("LlmModule", self.config) self.text_question_template = PromptTemplate(text_question_template_str) self.text_qa_template = PromptTemplate( text_qa_template_str, prompt_type=PromptType.QUESTION_ANSWER diff --git a/src/pai_rag/modules/__init__.py b/src/pai_rag/modules/__init__.py index a23c466b..3491bb86 100644 --- a/src/pai_rag/modules/__init__.py +++ b/src/pai_rag/modules/__init__.py @@ -1,5 +1,6 @@ from pai_rag.modules.embedding.embedding import EmbeddingModule from pai_rag.modules.llm.llm_module import LlmModule +from pai_rag.modules.datareader.data_loader 
import DataLoaderModule from pai_rag.modules.datareader.datareader_factory import DataReaderFactoryModule from pai_rag.modules.index.index import IndexModule from pai_rag.modules.nodeparser.node_parser import NodeParserModule @@ -12,10 +13,13 @@ from pai_rag.modules.chat.chat_store import ChatStoreModule from pai_rag.modules.agent.agent import AgentModule from pai_rag.modules.tool.tool import ToolModule +from pai_rag.modules.cache.oss_cache import OssCacheModule + ALL_MODULES = [ "EmbeddingModule", "LlmModule", + "DataLoaderModule", "DataReaderFactoryModule", "IndexModule", "NodeParserModule", @@ -28,6 +32,7 @@ "LlmChatEngineFactoryModule", "AgentModule", "ToolModule", + "OssCacheModule", ] __all__ = ALL_MODULES + ["ALL_MODULES"] diff --git a/src/pai_rag/modules/base/configurable_module.py b/src/pai_rag/modules/base/configurable_module.py index 78e9b089..de9e3834 100644 --- a/src/pai_rag/modules/base/configurable_module.py +++ b/src/pai_rag/modules/base/configurable_module.py @@ -1,19 +1,19 @@ from abc import ABC, abstractmethod from typing import Dict, List, Any +import logging DEFAULT_INSTANCE_KEY = "__DEFAULT_INSTANCE__" +logger = logging.getLogger(__name__) + + class ConfigurableModule(ABC): """Configurable Module Helps to create instances according to configuration. """ - def __init__(self): - self.__params_map = {} - self.__instance_map = {} - @abstractmethod def _create_new_instance(self, new_params: Dict[str, Any]): raise NotImplementedError @@ -24,20 +24,4 @@ def get_dependencies() -> List[str]: raise NotImplementedError def get_or_create(self, new_params: Dict[str, Any]): - return self.get_or_create_by_name(new_params=new_params) - - def get_or_create_by_name( - self, new_params: Dict[str, Any], name: str = DEFAULT_INSTANCE_KEY - ): - # Create new instance when initializing or config changed. 
- if ( - self.__params_map.get(name, None) is None - or self.__params_map[name] != new_params - ): - print(f"{self.__class__.__name__} param changed, updating") - self.__instance_map[name] = self._create_new_instance(new_params) - self.__params_map[name] = new_params - else: - print(f"{self.__class__.__name__} param unchanged, skipping") - - return self.__instance_map[name] + return self._create_new_instance(new_params) diff --git a/src/pai_rag/modules/cache/oss_cache.py b/src/pai_rag/modules/cache/oss_cache.py new file mode 100644 index 00000000..d4a543b9 --- /dev/null +++ b/src/pai_rag/modules/cache/oss_cache.py @@ -0,0 +1,20 @@ +from typing import Any, Dict, List +from pai_rag.utils.oss_cache import OssCache +from pai_rag.modules.base.configurable_module import ConfigurableModule +from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG +import logging + +logger = logging.getLogger(__name__) + + +class OssCacheModule(ConfigurableModule): + @staticmethod + def get_dependencies() -> List[str]: + return [] + + def _create_new_instance(self, new_params: Dict[str, Any]): + cache_config = new_params[MODULE_PARAM_CONFIG] + if cache_config: + return OssCache(cache_config) + else: + return None diff --git a/src/pai_rag/modules/chat/chat_engine_factory.py b/src/pai_rag/modules/chat/chat_engine_factory.py index ea8eea29..69282019 100644 --- a/src/pai_rag/modules/chat/chat_engine_factory.py +++ b/src/pai_rag/modules/chat/chat_engine_factory.py @@ -18,17 +18,11 @@ logger = logging.getLogger(__name__) -class ChatEngineFactoryModule(ConfigurableModule): - @staticmethod - def get_dependencies() -> List[str]: - return ["QueryEngineModule", "ChatStoreModule"] - - def _create_new_instance(self, new_params: Dict[str, Any]): - self.config = new_params[MODULE_PARAM_CONFIG] - self.query_engine = new_params["QueryEngineModule"] - self.chat_store = new_params["ChatStoreModule"] - - return self +class ChatEngineFactory: + def __init__(self, chat_type, query_engine, chat_store): + self.chat_type = chat_type + self.query_engine = query_engine + self.chat_store = chat_store def get_chat_engine(self, session_id, chat_history): chat_memory = self.chat_store.get_chat_memory_buffer(session_id) @@ -36,7 +30,8 @@ def get_chat_engine(self, session_id, chat_history): history_messages = parse_chat_messages(chat_history) for hist_mes in history_messages: chat_memory.put(hist_mes) - if self.config.type == "CondenseQuestionChatEngine": + + if self.chat_type == "CondenseQuestionChatEngine": my_chat_engine = CondenseQuestionChatEngine.from_defaults( query_engine=self.query_engine, condense_question_prompt=CONDENSE_QUESTION_CHAT_ENGINE_PROMPT_ZH, @@ -47,3 +42,18 @@ return my_chat_engine else: - raise ValueError(f"Unknown chat_engine_type: {self.config.type}") + raise ValueError(f"Unknown chat_engine_type: {self.chat_type}") + + +class ChatEngineFactoryModule(ConfigurableModule): + @staticmethod + def get_dependencies() -> List[str]: + return ["QueryEngineModule", "ChatStoreModule"] + + def _create_new_instance(self, new_params: Dict[str, Any]): + config = new_params[MODULE_PARAM_CONFIG] + query_engine = new_params["QueryEngineModule"] + chat_store = new_params["ChatStoreModule"] + + return ChatEngineFactory( + config.type, query_engine=query_engine, chat_store=chat_store + ) diff --git a/src/pai_rag/modules/chat/llm_chat_engine_factory.py b/src/pai_rag/modules/chat/llm_chat_engine_factory.py index 815d0d3e..0b90cb67 100644 --- a/src/pai_rag/modules/chat/llm_chat_engine_factory.py +++ 
b/src/pai_rag/modules/chat/llm_chat_engine_factory.py @@ -12,16 +12,11 @@ logger = logging.getLogger(__name__) -class LlmChatEngineFactoryModule(ConfigurableModule): - @staticmethod - def get_dependencies() -> List[str]: - return ["LlmModule", "ChatStoreModule"] - - def _create_new_instance(self, new_params: Dict[str, Any]): - self.config = new_params[MODULE_PARAM_CONFIG] - self.llm = new_params["LlmModule"] - self.chat_store = new_params["ChatStoreModule"] - return self +class LlmChatEngineFactory: + def __init__(self, chat_type, llm, chat_store): + self.chat_type = chat_type + self.llm = llm + self.chat_store = chat_store def get_chat_engine(self, session_id, chat_history): chat_memory = self.chat_store.get_chat_memory_buffer(session_id) @@ -30,14 +25,26 @@ def get_chat_engine(self, session_id, chat_history): for hist_mes in history_messages: chat_memory.put(hist_mes) - if self.config.type == "SimpleChatEngine": + if self.chat_type == "SimpleChatEngine": my_chat_engine = SimpleChatEngine.from_defaults( llm=self.llm, memory=chat_memory, verbose=True, ) logger.info("simple chat_engine instance created") + + return my_chat_engine else: - raise ValueError(f"Unknown chat_engine_type: {self.config.type}") + raise ValueError(f"Unknown chat_engine_type: {self.chat_type}") - return my_chat_engine + +class LlmChatEngineFactoryModule(ConfigurableModule): + @staticmethod + def get_dependencies() -> List[str]: + return ["LlmModule", "ChatStoreModule"] + + def _create_new_instance(self, new_params: Dict[str, Any]): + config = new_params[MODULE_PARAM_CONFIG] + llm = new_params["LlmModule"] + chat_store = new_params["ChatStoreModule"] + return LlmChatEngineFactory(config.type, llm, chat_store) diff --git a/src/pai_rag/modules/datareader/data_loader.py b/src/pai_rag/modules/datareader/data_loader.py new file mode 100644 index 00000000..5e5e0383 --- /dev/null +++ b/src/pai_rag/modules/datareader/data_loader.py @@ -0,0 +1,25 @@ +from typing import Any, Dict, List +from pai_rag.modules.base.configurable_module import ConfigurableModule +from pai_rag.data.rag_dataloader import RagDataLoader +import logging + +logger = logging.getLogger(__name__) + + +class DataLoaderModule(ConfigurableModule): + @staticmethod + def get_dependencies() -> List[str]: + return [ + "OssCacheModule", + "DataReaderFactoryModule", + "NodeParserModule", + "IndexModule", + ] + + def _create_new_instance(self, new_params: Dict[str, Any]): + oss_cache = new_params["OssCacheModule"] + data_reader_factory = new_params["DataReaderFactoryModule"] + node_parser = new_params["NodeParserModule"] + index = new_params["IndexModule"] + + return RagDataLoader(data_reader_factory, node_parser, index, oss_cache) diff --git a/src/pai_rag/modules/embedding/embedding.py b/src/pai_rag/modules/embedding/embedding.py index 75efbb93..ed7eb24c 100644 --- a/src/pai_rag/modules/embedding/embedding.py +++ b/src/pai_rag/modules/embedding/embedding.py @@ -3,9 +3,9 @@ from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding from llama_index.embeddings.dashscope import DashScopeEmbedding -from llama_index.embeddings.huggingface import HuggingFaceEmbedding from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG +from pai_rag.modules.embedding.my_huggingface_embedding import MyHuggingFaceEmbedding from pai_rag.utils.constants import DEFAULT_MODEL_DIR import os import logging @@ -52,7 +52,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): model_name = 
config.get("model_name", DEFAULT_HUGGINGFACE_EMBEDDING_MODEL) model_path = os.path.join(model_dir, model_name) - embed_model = HuggingFaceEmbedding( + embed_model = MyHuggingFaceEmbedding( model_name=model_path, embed_batch_size=embed_batch_size, ) diff --git a/src/pai_rag/modules/embedding/my_huggingface_embedding.py b/src/pai_rag/modules/embedding/my_huggingface_embedding.py new file mode 100644 index 00000000..f21f8cbb --- /dev/null +++ b/src/pai_rag/modules/embedding/my_huggingface_embedding.py @@ -0,0 +1,143 @@ +import logging +from typing import Any, List, Optional + +from llama_index.core.base.embeddings.base import ( + DEFAULT_EMBED_BATCH_SIZE, + BaseEmbedding, +) +from llama_index.core.bridge.pydantic import Field, PrivateAttr +from llama_index.core.callbacks import CallbackManager +from llama_index.core.utils import get_cache_dir, infer_torch_device +from llama_index.embeddings.huggingface.utils import ( + DEFAULT_HUGGINGFACE_EMBEDDING_MODEL, + get_query_instruct_for_model_name, + get_text_instruct_for_model_name, +) +from sentence_transformers import SentenceTransformer + +DEFAULT_HUGGINGFACE_LENGTH = 512 +logger = logging.getLogger(__name__) + + +# Add async interfaces +class MyHuggingFaceEmbedding(BaseEmbedding): + max_length: int = Field( + default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0 + ) + normalize: bool = Field(default=True, description="Normalize embeddings or not.") + query_instruction: Optional[str] = Field( + description="Instruction to prepend to query text." + ) + text_instruction: Optional[str] = Field( + description="Instruction to prepend to text." + ) + cache_folder: Optional[str] = Field( + description="Cache folder for Hugging Face files." + ) + + _model: Any = PrivateAttr() + _device: str = PrivateAttr() + + def __init__( + self, + model_name: str = DEFAULT_HUGGINGFACE_EMBEDDING_MODEL, + tokenizer_name: Optional[str] = "deprecated", + pooling: str = "deprecated", + max_length: Optional[int] = None, + query_instruction: Optional[str] = None, + text_instruction: Optional[str] = None, + normalize: bool = True, + model: Optional[Any] = "deprecated", + tokenizer: Optional[Any] = "deprecated", + embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + cache_folder: Optional[str] = None, + trust_remote_code: bool = False, + device: Optional[str] = None, + callback_manager: Optional[CallbackManager] = None, + **model_kwargs, + ): + self._device = device or infer_torch_device() + + cache_folder = cache_folder or get_cache_dir() + + for variable, value in [ + ("model", model), + ("tokenizer", tokenizer), + ("pooling", pooling), + ("tokenizer_name", tokenizer_name), + ]: + if value != "deprecated": + raise ValueError( + f"{variable} is deprecated. Please remove it from the arguments." 
+ ) + if model_name is None: + raise ValueError("The `model_name` argument must be provided.") + + self._model = SentenceTransformer( + model_name, + device=self._device, + cache_folder=cache_folder, + trust_remote_code=trust_remote_code, + prompts={ + "query": query_instruction + or get_query_instruct_for_model_name(model_name), + "text": text_instruction + or get_text_instruct_for_model_name(model_name), + }, + **model_kwargs, + ) + if max_length: + self._model.max_seq_length = max_length + else: + max_length = self._model.max_seq_length + + super().__init__( + embed_batch_size=embed_batch_size, + callback_manager=callback_manager, + model_name=model_name, + max_length=max_length, + normalize=normalize, + query_instruction=query_instruction, + text_instruction=text_instruction, + ) + + @classmethod + def class_name(cls) -> str: + return "HuggingFaceEmbedding" + + def _embed( + self, + sentences: List[str], + prompt_name: Optional[str] = None, + ) -> List[List[float]]: + """Embed sentences.""" + return self._model.encode( + sentences, + batch_size=self.embed_batch_size, + prompt_name=prompt_name, + normalize_embeddings=self.normalize, + ).tolist() + + def _get_query_embedding(self, query: str) -> List[float]: + """Get query embedding.""" + return self._embed(query, prompt_name="query") + + async def _aget_query_embedding(self, query: str) -> List[float]: + """Get query embedding async.""" + return self._get_query_embedding(query) + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Get text embedding async.""" + return self._get_text_embedding(text) + + def _get_text_embedding(self, text: str) -> List[float]: + """Get text embedding.""" + return self._embed(text, prompt_name="text") + + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get text embeddings.""" + return self._embed(texts, prompt_name="text") + + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get text embeddings.""" + return self._embed(texts, prompt_name="text") diff --git a/src/pai_rag/modules/index/index.py b/src/pai_rag/modules/index/index.py index 248fbc96..84e22697 100644 --- a/src/pai_rag/modules/index/index.py +++ b/src/pai_rag/modules/index/index.py @@ -3,12 +3,12 @@ import sys from typing import Dict, List, Any -from llama_index.core import VectorStoreIndex - -from llama_index.core import load_index_from_storage +from pai_rag.modules.index.my_vector_store_index import MyVectorStoreIndex +from pai_rag.modules.index.index_utils import load_index_from_storage from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG from pai_rag.modules.index.store import RagStore +from llama_index.vector_stores.faiss import FaissVectorStore from pai_rag.utils.store_utils import get_store_persist_directory_name, store_path logging.basicConfig( @@ -18,7 +18,6 @@ ) DEFAULT_PERSIST_DIR = "./storage" -INDEX_STATE_FILE = "index.state.json" class IndexModule(ConfigurableModule): @@ -33,43 +32,44 @@ class IndexModule(ConfigurableModule): def get_dependencies() -> List[str]: return ["EmbeddingModule"] - def _get_embed_vec_dim(self): + def _get_embed_vec_dim(self, embed_model): # Get dimension size of embedding vector - return len(self.embed_model._get_text_embedding("test")) + return len(embed_model._get_text_embedding("test")) def _create_new_instance(self, new_params: Dict[str, Any]): - self.config = new_params[MODULE_PARAM_CONFIG] - self.embed_model = 
new_params["EmbeddingModule"] - self.embed_dims = self._get_embed_vec_dim() - persist_path = self.config.get("persist_path", DEFAULT_PERSIST_DIR) - folder_name = get_store_persist_directory_name(self.config, self.embed_dims) + config = new_params[MODULE_PARAM_CONFIG] + embed_model = new_params["EmbeddingModule"] + embed_dims = self._get_embed_vec_dim(embed_model) + persist_path = config.get("persist_path", DEFAULT_PERSIST_DIR) + folder_name = get_store_persist_directory_name(config, embed_dims) store_path.persist_path = os.path.join(persist_path, folder_name) + is_empty = not os.path.exists(store_path.persist_path) + rag_store = RagStore(config, store_path.persist_path, is_empty, embed_dims) + storage_context = rag_store.get_storage_context() - self.is_empty = not os.path.exists(store_path.persist_path) - rag_store = RagStore( - self.config, store_path.persist_path, self.is_empty, self.embed_dims - ) - self.storage_context = rag_store.get_storage_context() - - if self.is_empty: - return self.create_indices() + if is_empty: + return self.create_indices(storage_context, embed_model) else: - return self.load_indices() + return self.load_indices(storage_context, embed_model) - def create_indices(self): + def create_indices(self, storage_context, embed_model): logging.info("Empty index, need to create indices.") - vector_index = VectorStoreIndex( - nodes=[], - storage_context=self.storage_context, - embed_model=self.embed_model, - store_nodes_override=True, + vector_index = MyVectorStoreIndex( + nodes=[], storage_context=storage_context, embed_model=embed_model ) logging.info("Created vector_index.") return vector_index - def load_indices(self): - vector_index = load_index_from_storage(storage_context=self.storage_context) - + def load_indices(self, storage_context, embed_model): + if isinstance(storage_context.vector_store, FaissVectorStore): + vector_index = load_index_from_storage(storage_context=storage_context) + return vector_index + else: + vector_index = MyVectorStoreIndex( + nodes=[], + storage_context=storage_context, + embed_model=embed_model, + ) return vector_index diff --git a/src/pai_rag/modules/index/index_utils.py b/src/pai_rag/modules/index/index_utils.py new file mode 100644 index 00000000..8f1ff4e5 --- /dev/null +++ b/src/pai_rag/modules/index/index_utils.py @@ -0,0 +1,109 @@ +import logging +from typing import Any, List, Optional, Sequence + +from llama_index.core.indices.base import BaseIndex +from llama_index.core.indices.composability.graph import ComposableGraph +from llama_index.core.indices.registry import INDEX_STRUCT_TYPE_TO_INDEX_CLASS +from llama_index.core.storage.storage_context import StorageContext +from llama_index.core.data_structs.struct_type import IndexStructType + +from pai_rag.modules.index.my_vector_store_index import MyVectorStoreIndex + + +MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS = INDEX_STRUCT_TYPE_TO_INDEX_CLASS +MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS[ + IndexStructType.VECTOR_STORE +] = MyVectorStoreIndex + +logger = logging.getLogger(__name__) + + +def load_index_from_storage( + storage_context: StorageContext, + index_id: Optional[str] = None, + **kwargs: Any, +) -> BaseIndex: + """Load index from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + index_id (Optional[str]): ID of the index to load. + Defaults to None, which assumes there's only a single index + in the index store and load it. 
+ **kwargs: Additional keyword args to pass to the index constructors. + """ + index_ids: Optional[Sequence[str]] + if index_id is None: + index_ids = None + else: + index_ids = [index_id] + + indices = load_indices_from_storage(storage_context, index_ids=index_ids, **kwargs) + + if len(indices) == 0: + raise ValueError( + "No index in storage context, check if you specified the right persist_dir." + ) + elif len(indices) > 1: + raise ValueError( + f"Expected to load a single index, but got {len(indices)} instead. " + "Please specify index_id." + ) + + return indices[0] + + +def load_indices_from_storage( + storage_context: StorageContext, + index_ids: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> List[BaseIndex]: + """Load multiple indices from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + index_ids (Optional[Sequence[str]]): IDs of the indices to load. + Defaults to None, which loads all indices in the index store. + **kwargs: Additional keyword args to pass to the index constructors. + """ + if index_ids is None: + logger.info("Loading all indices.") + index_structs = storage_context.index_store.index_structs() + else: + logger.info(f"Loading indices with ids: {index_ids}") + index_structs = [] + for index_id in index_ids: + index_struct = storage_context.index_store.get_index_struct(index_id) + if index_struct is None: + raise ValueError(f"Failed to load index with ID {index_id}") + index_structs.append(index_struct) + + indices = [] + for index_struct in index_structs: + type_ = index_struct.get_type() + index_cls = MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS[type_] + index = index_cls( + index_struct=index_struct, storage_context=storage_context, **kwargs + ) + indices.append(index) + return indices + + +def load_graph_from_storage( + storage_context: StorageContext, + root_id: str, + **kwargs: Any, +) -> ComposableGraph: + """Load composable graph from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + root_id (str): ID of the root index of the graph. + **kwargs: Additional keyword args to pass to the index constructors. + """ + indices = load_indices_from_storage(storage_context, index_ids=None, **kwargs) + all_indices = {index.index_id: index for index in indices} + return ComposableGraph(all_indices=all_indices, root_id=root_id) diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py new file mode 100644 index 00000000..b4fd3f5a --- /dev/null +++ b/src/pai_rag/modules/index/my_vector_store_index.py @@ -0,0 +1,119 @@ +"""Base vector store index. + +An index that is built on top of an existing vector store. 
+ +""" + +import asyncio +import logging +from typing import Any, Sequence +from llama_index.core import VectorStoreIndex +from llama_index.core.data_structs.data_structs import IndexDict +from llama_index.core.schema import ( + BaseNode, + IndexNode, +) +from llama_index.core.utils import iter_batch + +logger = logging.getLogger(__name__) + + +def call_async(coro): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + else: + return loop.run_until_complete(coro) + + +class MyVectorStoreIndex(VectorStoreIndex): + async def _process_one_batch( + self, + nodes_batch: Sequence[Sequence[BaseNode]], + index_struct: IndexDict, + semaphore: asyncio.Semaphore, + **insert_kwargs: Any, + ): + async with semaphore: + new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs) + + # if the vector store doesn't store text, we need to add the nodes to the + # index struct and document store + if not self._vector_store.stores_text or self._store_nodes_override: + for node, new_id in zip(nodes_batch, new_ids): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + + async def _postprocess_all_batch( + self, + nodes_batch_list: Sequence[Sequence[BaseNode]], + index_struct: IndexDict, + **insert_kwargs: Any, + ): + asyncio_semaphore = asyncio.Semaphore(10) + batch_process_coroutines = [] + for nodes_batch in nodes_batch_list: + batch_process_coroutines.append( + self._process_one_batch( + nodes_batch, index_struct, asyncio_semaphore, **insert_kwargs + ) + ) + await asyncio.gather(*batch_process_coroutines) + + async def _async_add_nodes_to_index( + self, + index_struct: IndexDict, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **insert_kwargs: Any, + ) -> None: + """Asynchronously add nodes to index.""" + if not nodes: + return + + node_batch_list = [] + for nodes_batch in iter_batch(nodes, 100): + nodes_batch = await self._aget_node_with_embedding( + nodes_batch, show_progress + ) + node_batch_list.append(nodes_batch) + + await self._postprocess_all_batch( + node_batch_list, index_struct, **insert_kwargs + ) + + async def _insert_async( + self, nodes: Sequence[BaseNode], **insert_kwargs: Any + ) -> None: + """Insert a document.""" + await self._async_add_nodes_to_index( + self._index_struct, nodes, show_progress=True, **insert_kwargs + ) + + async def insert_nodes_async( + self, nodes: Sequence[BaseNode], **insert_kwargs: Any + ) -> None: + """Insert nodes. + + NOTE: overrides BaseIndex.insert_nodes. 
+ VectorStoreIndex only stores nodes in document store + if vector store does not store text + """ + for node in nodes: + if isinstance(node, IndexNode): + try: + node.dict() + except ValueError: + self._object_map[node.index_id] = node.obj + node.obj = None + + with self._callback_manager.as_trace("insert_nodes"): + await self._insert_async(nodes, **insert_kwargs) + self._storage_context.index_store.add_index_struct(self._index_struct) diff --git a/src/pai_rag/modules/index/store.py b/src/pai_rag/modules/index/store.py index 1a290cdd..4d423f5b 100644 --- a/src/pai_rag/modules/index/store.py +++ b/src/pai_rag/modules/index/store.py @@ -6,7 +6,6 @@ from llama_index.vector_stores.analyticdb import AnalyticDBVectorStore from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.chroma import ChromaVectorStore -from llama_index.vector_stores.elasticsearch import ElasticsearchStore from llama_index.vector_stores.milvus import MilvusVectorStore from elasticsearch.helpers.vectorstore import AsyncDenseVectorStrategy @@ -16,6 +15,8 @@ from llama_index.core import StorageContext import logging +from pai_rag.modules.retriever.my_elasticsearch_store import MyElasticsearchStore + DEFAULT_CHROMA_COLLECTION_NAME = "pairag" logger = logging.getLogger(__name__) @@ -36,14 +37,19 @@ def _get_or_create_storage_context(self): self.vector_store = None self.doc_store = None self.index_store = None + persist_dir = None vector_store_type = ( - self.store_config["vector_store"].get("type", "chroma").lower() + self.store_config["vector_store"].get("type", "faiss").lower() ) + if vector_store_type == "chroma": self.vector_store = self._get_or_create_chroma() logger.info("initialized Chroma vector store.") elif vector_store_type == "faiss": + self.doc_store = self._get_or_create_simple_doc_store() + self.index_store = self._get_or_create_simple_index_store() + persist_dir = self.persist_dir self.vector_store = self._get_or_create_faiss() logger.info("initialized FAISS vector store.") elif vector_store_type == "hologres": @@ -60,14 +66,11 @@ def _get_or_create_storage_context(self): else: raise ValueError(f"Unknown vector_store type '{vector_store_type}'.") - self.doc_store = self._get_or_create_simple_doc_store() - self.index_store = self._get_or_create_simple_index_store() - storage_context = StorageContext.from_defaults( docstore=self.doc_store, index_store=self.index_store, vector_store=self.vector_store, - persist_dir=self.persist_dir, + persist_dir=persist_dir, ) return storage_context @@ -124,12 +127,15 @@ def _get_or_create_adb(self): def _get_or_create_es(self): es_config = self.store_config["vector_store"] - return ElasticsearchStore( + return MyElasticsearchStore( index_name=es_config["es_index"], es_url=es_config["es_url"], es_user=es_config["es_user"], es_password=es_config["es_password"], - retrieval_strategy=AsyncDenseVectorStrategy(hybrid=True), + embedding_dimension=self.embed_dims, + retrieval_strategy=AsyncDenseVectorStrategy( + hybrid=True, rrf={"window_size": 50} + ), ) def _get_or_create_milvus(self): diff --git a/src/pai_rag/modules/module_registry.py b/src/pai_rag/modules/module_registry.py index 12c65ce3..cc2dca46 100644 --- a/src/pai_rag/modules/module_registry.py +++ b/src/pai_rag/modules/module_registry.py @@ -1,5 +1,8 @@ +import hashlib +from typing import Dict, Any from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG import pai_rag.modules as modules +import logging MODULE_CONFIG_KEY_MAP = { "IndexModule": "index", @@ -16,17 +19,25 @@ 
"DataReaderFactoryModule": "data_reader", "AgentModule": "agent", "ToolModule": "tool", + "DataLoaderModule": "data_loader", + "OssCacheModule": "cache", } +logger = logging.getLogger(__name__) + + class ModuleRegistry: def __init__(self): - self._mod_instance_map = {} self._mod_cls_map = {} self._mod_deps_map = {} self._mod_deps_map_inverted = {} + self._mod_instance_map = {} + for m_name in modules.ALL_MODULES: + self._mod_instance_map[m_name] = {} + m_cls = getattr(modules, m_name) self._mod_cls_map[m_name] = m_cls() @@ -38,12 +49,16 @@ def __init__(self): self._mod_deps_map_inverted[dep] = [] self._mod_deps_map_inverted[dep].append(m_name) - def get_module(self, module_key: str): - return self._mod_instance_map[module_key] + def _get_param_hash(self, params: Dict[str, Any]): + repr_str = repr(sorted(params.items())).encode("utf-8") + return hashlib.sha256(repr_str).hexdigest() + + def get_module_with_config(self, module_key, config): + return self._create_mod_lazily(module_key, config) def init_modules(self, config): + mod_cache = {} mod_stack = [] - mods_inited = [] mod_ref_count = {} for mod, deps in self._mod_deps_map.items(): ref_count = len(deps) @@ -53,9 +68,8 @@ def init_modules(self, config): while mod_stack: mod = mod_stack.pop() - mod_obj = self._init_mod(mod, config) - mods_inited.append(mod) - self._mod_instance_map[mod] = mod_obj + mod_obj = self._create_mod_lazily(mod, config, mod_cache) + mod_cache[mod] = mod_obj # update module ref count that depends on on ref_mods = self._mod_deps_map_inverted.get(mod, []) @@ -64,23 +78,35 @@ def init_modules(self, config): if mod_ref_count[ref_mod] == 0: mod_stack.append(ref_mod) - if len(mods_inited) != len(modules.ALL_MODULES): + if len(mod_cache) != len(modules.ALL_MODULES): # dependency circular error! raise ValueError( - f"Circular dependency detected. Please check module dependency configuration. Module initialized: {mods_inited}. Module ref count: {mod_ref_count}" + f"Circular dependency detected. Please check module dependency configuration. Module initialized: {mod_cache}. Module ref count: {mod_ref_count}" ) - print(f"RAG modules init successfully. {mods_inited}") + logger.info(f"RAG modules init successfully. 
{mod_cache.keys()}") return - def _init_mod(self, mod_name, config): + def _create_mod_lazily(self, mod_name, config, mod_cache=None): + if mod_cache and mod_name in mod_cache: + return mod_cache[mod_name] + + logger.info(f"Get module {mod_name}.") + mod_config_key = MODULE_CONFIG_KEY_MAP[mod_name] mod_deps = self._mod_deps_map[mod_name] mod_cls = self._mod_cls_map[mod_name] - params = {MODULE_PARAM_CONFIG: config[mod_config_key]} + params = {MODULE_PARAM_CONFIG: config.get(mod_config_key, None)} for dep in mod_deps: - params[dep] = self.get_module(dep) - return mod_cls.get_or_create(params) + params[dep] = self._create_mod_lazily(dep, config, mod_cache) + + instance_key = self._get_param_hash(params) + if instance_key not in self._mod_instance_map[mod_name]: + logger.info(f"Creating new instance for module {mod_name} {instance_key}.") + self._mod_instance_map[mod_name][instance_key] = mod_cls.get_or_create( + params + ) + return self._mod_instance_map[mod_name][instance_key] module_registry = ModuleRegistry() diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py b/src/pai_rag/modules/retriever/my_elasticsearch_store.py new file mode 100644 index 00000000..63205ccc --- /dev/null +++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py @@ -0,0 +1,538 @@ +"""Elasticsearch vector store.""" + +import asyncio +from logging import getLogger +from typing import Any, Callable, Dict, List, Literal, Optional, Union + +import nest_asyncio +import numpy as np +from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.schema import BaseNode, MetadataMode, TextNode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryMode, + VectorStoreQueryResult, +) +from llama_index.core.vector_stores.utils import ( + metadata_dict_to_node, + node_to_metadata_dict, +) +from elasticsearch.helpers.vectorstore import AsyncVectorStore +from elasticsearch.helpers.vectorstore import ( + AsyncBM25Strategy, + AsyncSparseVectorStrategy, + AsyncDenseVectorStrategy, + AsyncRetrievalStrategy, + DistanceMetric, +) + +from llama_index.vector_stores.elasticsearch.utils import ( + get_elasticsearch_client, + get_user_agent, +) + +logger = getLogger(__name__) + +DISTANCE_STRATEGIES = Literal[ + "COSINE", + "DOT_PRODUCT", + "EUCLIDEAN_DISTANCE", +] + + +def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]: + """ + Convert standard filters to Elasticsearch filter. + + Args: + standard_filters: Standard Llama-index filters. + + Returns: + Elasticsearch filter. + """ + if len(standard_filters.legacy_filters()) == 1: + filter = standard_filters.legacy_filters()[0] + return { + "term": { + f"metadata.{filter.key}.keyword": { + "value": filter.value, + } + } + } + else: + operands = [] + for filter in standard_filters.legacy_filters(): + operands.append( + { + "term": { + f"metadata.{filter.key}.keyword": { + "value": filter.value, + } + } + } + ) + return {"bool": {"must": operands}} + + +def _to_llama_similarities(scores: List[float]) -> List[float]: + if scores is None or len(scores) == 0: + return [] + + scores_to_norm: np.ndarray = np.array(scores) + return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist() + + +def _mode_must_match_retrieval_strategy( + mode: VectorStoreQueryMode, retrieval_strategy: AsyncRetrievalStrategy +) -> None: + """ + Different retrieval strategies require different ways of indexing that must be known at the + time of adding data. 
The query mode is known at query time. This function checks if the retrieval strategy (and way of indexing) is compatible with the query mode and raises an exception in the case of a mismatch. + """ + if mode == VectorStoreQueryMode.DEFAULT: + # it's fine to not specify an explicit other mode + return + + mode_retrieval_dict = { + VectorStoreQueryMode.SPARSE: AsyncSparseVectorStrategy, + VectorStoreQueryMode.TEXT_SEARCH: AsyncBM25Strategy, + VectorStoreQueryMode.HYBRID: AsyncDenseVectorStrategy, + } + + required_strategy = mode_retrieval_dict.get(mode) + if not required_strategy: + raise NotImplementedError(f"query mode {mode} currently not supported") + + if not isinstance(retrieval_strategy, required_strategy): + raise ValueError( + f"query mode {mode} incompatible with retrieval strategy {type(retrieval_strategy)}, " + f"expected {required_strategy}" + ) + + if mode == VectorStoreQueryMode.HYBRID and not retrieval_strategy.hybrid: + raise ValueError("to enable hybrid mode, it must be set in retrieval strategy") + + +class MyElasticsearchStore(BasePydanticVectorStore): + """ + Elasticsearch vector store. + + Args: + index_name: Name of the Elasticsearch index. + es_client: Optional. Pre-existing AsyncElasticsearch client. + es_url: Optional. Elasticsearch URL. + es_cloud_id: Optional. Elasticsearch cloud ID. + es_api_key: Optional. Elasticsearch API key. + es_user: Optional. Elasticsearch username. + es_password: Optional. Elasticsearch password. + text_field: Optional. Name of the Elasticsearch field that stores the text. + vector_field: Optional. Name of the Elasticsearch field that stores the + embedding. + batch_size: Optional. Batch size for bulk indexing. Defaults to 200. + distance_strategy: Optional. Distance strategy to use for similarity search. + Defaults to "COSINE". + retrieval_strategy: Retrieval strategy to use. AsyncBM25Strategy / + AsyncSparseVectorStrategy / AsyncDenseVectorStrategy / AsyncRetrievalStrategy. + Defaults to AsyncDenseVectorStrategy. + + Raises: + ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch. + ValueError: If neither es_client nor es_url nor es_cloud_id is provided. 
+ + Examples: + `pip install llama-index-vector-stores-elasticsearch` + + ```python + from llama_index.vector_stores import ElasticsearchStore + + # Additional setup for ElasticsearchStore class + index_name = "my_index" + es_url = "http://localhost:9200" + es_cloud_id = "" # Found within the deployment page + es_user = "elastic" + es_password = "" # Provided when creating deployment or can be reset + es_api_key = "" # Create an API key within Kibana (Security -> API Keys) + + # Connecting to ElasticsearchStore locally + es_local = ElasticsearchStore( + index_name=index_name, + es_url=es_url, + ) + + # Connecting to Elastic Cloud with username and password + es_cloud_user_pass = ElasticsearchStore( + index_name=index_name, + es_cloud_id=es_cloud_id, + es_user=es_user, + es_password=es_password, + ) + + # Connecting to Elastic Cloud with API Key + es_cloud_api_key = ElasticsearchStore( + index_name=index_name, + es_cloud_id=es_cloud_id, + es_api_key=es_api_key, + ) + ``` + + """ + + class Config: + # allow pydantic to tolerate its inability to validate AsyncRetrievalStrategy + arbitrary_types_allowed = True + + stores_text: bool = True + index_name: str + es_client: Optional[Any] + es_url: Optional[str] + es_cloud_id: Optional[str] + es_api_key: Optional[str] + es_user: Optional[str] + es_password: Optional[str] + text_field: str = "content" + vector_field: str = "embedding" + batch_size: int = 200 + distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE" + retrieval_strategy: AsyncRetrievalStrategy + + _store = PrivateAttr() + + def __init__( + self, + index_name: str, + es_client: Optional[Any] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_api_key: Optional[str] = None, + es_user: Optional[str] = None, + es_password: Optional[str] = None, + text_field: str = "content", + vector_field: str = "embedding", + embedding_dimension: int = 1536, + batch_size: int = 200, + distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE", + retrieval_strategy: Optional[AsyncRetrievalStrategy] = None, + ) -> None: + nest_asyncio.apply() + + if not es_client: + es_client = get_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + ) + + if retrieval_strategy is None: + retrieval_strategy = AsyncDenseVectorStrategy( + distance=DistanceMetric[distance_strategy] + ) + + metadata_mappings = { + "document_id": {"type": "keyword"}, + "doc_id": {"type": "keyword"}, + "ref_doc_id": {"type": "keyword"}, + } + + self._store = AsyncVectorStore( + user_agent=get_user_agent(), + client=es_client, + index=index_name, + retrieval_strategy=retrieval_strategy, + text_field=text_field, + vector_field=vector_field, + metadata_mappings=metadata_mappings, + num_dimensions=embedding_dimension, + ) + asyncio.get_event_loop().run_until_complete( + self._store._create_index_if_not_exists() + ) + + super().__init__( + index_name=index_name, + es_client=es_client, + es_url=es_url, + es_cloud_id=es_cloud_id, + es_api_key=es_api_key, + es_user=es_user, + es_password=es_password, + text_field=text_field, + vector_field=vector_field, + batch_size=batch_size, + distance_strategy=distance_strategy, + retrieval_strategy=retrieval_strategy, + ) + + @property + def client(self) -> Any: + """Get async elasticsearch client.""" + return self._store.client + + def close(self) -> None: + return asyncio.get_event_loop().run_until_complete(self._store.close()) + + def add( + self, + nodes: List[BaseNode], + *, 
create_index_if_not_exists: bool = True, + **add_kwargs: Any, + ) -> List[str]: + """ + Add nodes to Elasticsearch index. + + Args: + nodes: List of nodes with embeddings. + create_index_if_not_exists: Optional. Whether to create + the Elasticsearch index if it + doesn't already exist. + Defaults to True. + + Returns: + List of node IDs that were added to the index. + + Raises: + ImportError: If elasticsearch['async'] python package is not installed. + BulkIndexError: If AsyncElasticsearch async_bulk indexing fails. + """ + return asyncio.get_event_loop().run_until_complete( + self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists) + ) + + async def async_add( + self, + nodes: List[BaseNode], + *, + create_index_if_not_exists: bool = True, + **add_kwargs: Any, + ) -> List[str]: + """ + Asynchronous method to add nodes to Elasticsearch index. + + Args: + nodes: List of nodes with embeddings. + create_index_if_not_exists: Optional. Whether to create + the AsyncElasticsearch index if it + doesn't already exist. + Defaults to True. + + Returns: + List of node IDs that were added to the index. + + Raises: + ImportError: If elasticsearch python package is not installed. + BulkIndexError: If AsyncElasticsearch async_bulk indexing fails. + """ + if len(nodes) == 0: + return [] + + add_kwargs.update({"max_retries": 3}) + + embeddings: List[List[float]] = [] + texts: List[str] = [] + metadatas: List[dict] = [] + ids: List[str] = [] + for node in nodes: + ids.append(node.node_id) + embeddings.append(node.get_embedding()) + texts.append(node.get_content(metadata_mode=MetadataMode.NONE)) + metadatas.append(node_to_metadata_dict(node, remove_text=True)) + + if not self._store.num_dimensions: + self._store.num_dimensions = len(embeddings[0]) + + return await self._store.add_texts( + texts=texts, + metadatas=metadatas, + vectors=embeddings, + ids=ids, + create_index_if_not_exists=create_index_if_not_exists, + bulk_kwargs=add_kwargs, + ) + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """ + Delete node from Elasticsearch index. + + Args: + ref_doc_id: ID of the node to delete. + delete_kwargs: Optional. Additional arguments to + pass to Elasticsearch delete_by_query. + + Raises: + Exception: If Elasticsearch delete_by_query fails. + """ + return asyncio.get_event_loop().run_until_complete( + self.adelete(ref_doc_id, **delete_kwargs) + ) + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """ + Async delete node from Elasticsearch index. + + Args: + ref_doc_id: ID of the node to delete. + delete_kwargs: Optional. Additional arguments to + pass to AsyncElasticsearch delete_by_query. + + Raises: + Exception: If AsyncElasticsearch delete_by_query fails. + """ + await self._store.delete( + query={"term": {"metadata.ref_doc_id": ref_doc_id}}, **delete_kwargs + ) + + def query( + self, + query: VectorStoreQuery, + custom_query: Optional[ + Callable[[Dict, Union[VectorStoreQuery, None]], Dict] + ] = None, + es_filter: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VectorStoreQueryResult: + """ + Query index for top k most similar nodes. + + Args: + query (VectorStoreQuery): query, including the query embedding + custom_query: Optional. custom query function that takes in the es query + body and returns a modified query body. + This can be used to add additional query + parameters to the Elasticsearch query. + es_filter: Optional. Elasticsearch filter to apply to the + query. If filter is provided in the query, + this filter will be ignored. 
+ + Returns: + VectorStoreQueryResult: Result of the query. + + Raises: + Exception: If Elasticsearch query fails. + + """ + return asyncio.get_event_loop().run_until_complete( + self.aquery(query, custom_query, es_filter, **kwargs) + ) + + async def aquery( + self, + query: VectorStoreQuery, + custom_query: Optional[ + Callable[[Dict, Union[VectorStoreQuery, None]], Dict] + ] = None, + es_filter: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VectorStoreQueryResult: + """ + Asynchronous query index for top k most similar nodes. + + Args: + query (VectorStoreQuery): query, including the query embedding + custom_query: Optional. custom query function that takes in the es query + body and returns a modified query body. + This can be used to add additional query + parameters to the AsyncElasticsearch query. + es_filter: Optional. AsyncElasticsearch filter to apply to the + query. If filter is provided in the query, + this filter will be ignored. + + Returns: + VectorStoreQueryResult: Result of the query. + + Raises: + Exception: If AsyncElasticsearch query fails. + + """ + if query.mode == VectorStoreQueryMode.HYBRID: + retrieval_strategy = AsyncDenseVectorStrategy( + hybrid=True, rrf={"window_size": 50} + ) + elif query.mode == VectorStoreQueryMode.TEXT_SEARCH: + retrieval_strategy = AsyncBM25Strategy() + else: + retrieval_strategy = AsyncDenseVectorStrategy() + + metadata_mappings = { + "document_id": {"type": "keyword"}, + "doc_id": {"type": "keyword"}, + "ref_doc_id": {"type": "keyword"}, + } + self._store = AsyncVectorStore( + user_agent=get_user_agent(), + client=self.es_client, + index=self.index_name, + retrieval_strategy=retrieval_strategy, + text_field=self.text_field, + vector_field=self.vector_field, + metadata_mappings=metadata_mappings, + ) + + if query.filters is not None and len(query.filters.legacy_filters()) > 0: + filter = [_to_elasticsearch_filter(query.filters)] + else: + filter = es_filter or [] + + hits = await self._store.search( + query=query.query_str, + query_vector=query.query_embedding, + k=query.similarity_top_k, + num_candidates=query.similarity_top_k * 10, + filter=filter, + custom_query=custom_query, + ) + + top_k_nodes = [] + top_k_ids = [] + top_k_scores = [] + for hit in hits: + source = hit["_source"] + metadata = source.get("metadata", None) + text = source.get(self.text_field, None) + node_id = hit["_id"] + + try: + node = metadata_dict_to_node(metadata) + node.text = text + except Exception: + # Legacy support for old metadata format + logger.warning( + f"Could not parse metadata from hit {hit['_source']['metadata']}" + ) + node_info = source.get("node_info") + relationships = source.get("relationships", {}) + start_char_idx = None + end_char_idx = None + if isinstance(node_info, dict): + start_char_idx = node_info.get("start", None) + end_char_idx = node_info.get("end", None) + + node = TextNode( + text=text, + metadata=metadata, + id_=node_id, + start_char_idx=start_char_idx, + end_char_idx=end_char_idx, + relationships=relationships, + ) + top_k_nodes.append(node) + top_k_ids.append(node_id) + top_k_scores.append(hit.get("_rank", hit["_score"])) + + if ( + isinstance(self.retrieval_strategy, AsyncDenseVectorStrategy) + and self.retrieval_strategy.hybrid + ): + total_rank = sum(top_k_scores) + top_k_scores = [(total_rank - rank) / total_rank for rank in top_k_scores] + + return VectorStoreQueryResult( + nodes=top_k_nodes, + ids=top_k_ids, + similarities=_to_llama_similarities(top_k_scores), + ) diff --git 
a/src/pai_rag/modules/retriever/my_vector_index_retriever.py b/src/pai_rag/modules/retriever/my_vector_index_retriever.py
index 16c98946..755f02a6 100644
--- a/src/pai_rag/modules/retriever/my_vector_index_retriever.py
+++ b/src/pai_rag/modules/retriever/my_vector_index_retriever.py
@@ -25,7 +25,7 @@ class MyVectorIndexRetriever(VectorIndexRetriever):
     and return the results with the query_result.similarities sorted in descending order.

     Args:
-        index (VectorStoreIndex): vector store index.
+        index (VectorStoreIndex): vector store index backing this retriever.
         similarity_top_k (int): number of top k results to return.
         vector_store_query_mode (str): vector store query mode
             See reference for VectorStoreQueryMode for full list of supported modes.
diff --git a/src/pai_rag/modules/retriever/retriever.py b/src/pai_rag/modules/retriever/retriever.py
index bc7cbca3..d65b4a33 100644
--- a/src/pai_rag/modules/retriever/retriever.py
+++ b/src/pai_rag/modules/retriever/retriever.py
@@ -3,29 +3,23 @@ import logging
 from typing import Dict, List, Any

-import jieba
-from nltk.corpus import stopwords
 from llama_index.core.indices.list.base import SummaryIndex
 from llama_index.core.retrievers import QueryFusionRetriever
 from llama_index.core.tools import RetrieverTool
 from llama_index.core.selectors import LLMSingleSelector
 from llama_index.core.retrievers import RouterRetriever
+from llama_index.core.vector_stores.types import VectorStoreQueryMode

-# from llama_index.retrievers.bm25 import BM25Retriever
+from pai_rag.utils.tokenizer import jieba_tokenizer
 from pai_rag.integrations.retrievers.bm25 import BM25Retriever
 from pai_rag.modules.base.configurable_module import ConfigurableModule
 from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
 from pai_rag.utils.prompt_template import QUERY_GEN_PROMPT
+from pai_rag.modules.retriever.my_elasticsearch_store import MyElasticsearchStore
 from pai_rag.modules.retriever.my_vector_index_retriever import MyVectorIndexRetriever

 logger = logging.getLogger(__name__)

-stopword_list = stopwords.words("chinese") + stopwords.words("english")
-
-
-def jieba_tokenize(text: str) -> List[str]:
-    return [w for w in jieba.lcut(text) if w not in stopword_list]
-

 class RetrieverModule(ConfigurableModule):
     @staticmethod
@@ -37,23 +31,39 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
         vector_index = new_params["IndexModule"]

         similarity_top_k = config.get("similarity_top_k", 5)
-        # vector
+
+        retrieval_mode = config.get("retrieval_mode", "hybrid").lower()
+
+        # Special handling for Elasticsearch
+        if isinstance(vector_index.storage_context.vector_store, MyElasticsearchStore):
+            if retrieval_mode == "embedding":
+                query_mode = VectorStoreQueryMode.DEFAULT
+            elif retrieval_mode == "keyword":
+                query_mode = VectorStoreQueryMode.TEXT_SEARCH
+            else:
+                query_mode = VectorStoreQueryMode.HYBRID
+
+            return MyVectorIndexRetriever(
+                index=vector_index,
+                similarity_top_k=similarity_top_k,
+                vector_store_query_mode=query_mode,
+            )
+
         vector_retriever = MyVectorIndexRetriever(
             index=vector_index, similarity_top_k=similarity_top_k
         )
-        # keyword
         bm25_retriever = BM25Retriever.from_defaults(
             index=vector_index,
             similarity_top_k=similarity_top_k,
-            tokenizer=jieba_tokenize,
+            tokenizer=jieba_tokenizer,
         )

-        if config["retrieval_mode"] == "embedding":
+        if retrieval_mode == "embedding":
             logger.info(f"MyVectorIndexRetriever used with top_k {similarity_top_k}.")
             return vector_retriever
-        elif config["retrieval_mode"] == "keyword":
+        elif retrieval_mode == "keyword":
             logger.info(f"BM25Retriever used with top_k {similarity_top_k}.")
             return bm25_retriever
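Reviewer note: the net effect of this change is that `retrieval_mode` is now read case-insensitively with a `hybrid` default, and Elasticsearch-backed indexes short-circuit to a single `MyVectorIndexRetriever` whose query mode selects the server-side strategy, instead of building separate vector/BM25/fusion retrievers client-side. The mapping in isolation:

```python
# Standalone restatement of the mode mapping introduced above.
from llama_index.core.vector_stores.types import VectorStoreQueryMode


def resolve_es_query_mode(retrieval_mode: str) -> VectorStoreQueryMode:
    mode = retrieval_mode.lower()
    if mode == "embedding":
        return VectorStoreQueryMode.DEFAULT  # dense kNN only
    if mode == "keyword":
        return VectorStoreQueryMode.TEXT_SEARCH  # BM25 only
    return VectorStoreQueryMode.HYBRID  # dense + BM25 fused with RRF


assert resolve_es_query_mode("Keyword") is VectorStoreQueryMode.TEXT_SEARCH
```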
with top_k {similarity_top_k}.") return bm25_retriever diff --git a/src/pai_rag/utils/store_utils.py b/src/pai_rag/utils/store_utils.py index a42bb42c..0873375c 100644 --- a/src/pai_rag/utils/store_utils.py +++ b/src/pai_rag/utils/store_utils.py @@ -14,9 +14,9 @@ def get_store_persist_directory_name(storage_config, ndims): raw_text = "sample_store_key" vector_store_type = storage_config["vector_store"]["type"].lower() if vector_store_type == "chroma": - raw_text = json.dumps(storage_config["vector_store"]) + raw_text = json.dumps(storage_config["vector_store"], sort_keys=True) elif vector_store_type == "faiss": - raw_text = json.dumps(storage_config["vector_store"]) + raw_text = {"type": "faiss"} elif vector_store_type == "hologres": keywords = ["host", "port", "database", "table_name"] json_data = {k: storage_config["vector_store"][k] for k in keywords} diff --git a/src/pai_rag/utils/tokenizer.py b/src/pai_rag/utils/tokenizer.py new file mode 100644 index 00000000..168702b9 --- /dev/null +++ b/src/pai_rag/utils/tokenizer.py @@ -0,0 +1,16 @@ +import jieba +from nltk.corpus import stopwords +from typing import List + +stopword_list = stopwords.words("chinese") + stopwords.words("english") + + +## PUT in utils file and add stopword in TRIE structure. +def jieba_tokenizer(text: str) -> List[str]: + tokens = [] + for w in jieba.lcut(text): + token = w.lower() + if token not in stopword_list: + tokens.append(token) + + return tokens diff --git a/tests/core/test_rag_application.py b/tests/core/test_rag_application.py index 457d599c..ef27858f 100644 --- a/tests/core/test_rag_application.py +++ b/tests/core/test_rag_application.py @@ -28,12 +28,9 @@ def rag_app(): # Test load knowledge file -def test_add_knowledge_file(rag_app: RagApplication): +async def test_add_knowledge_file(rag_app: RagApplication): data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham") - print(len(rag_app.index.docstore.docs)) - rag_app.load_knowledge(data_dir) - print(len(rag_app.index.docstore.docs)) - assert len(rag_app.index.docstore.docs) > 0 + await rag_app.load_knowledge(data_dir) # Test rag query