From e5eb868e0bd0c5e4d175b5935f2d495b8e220b21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 09:25:31 +0800
Subject: [PATCH 01/17] Add gpu dockerfile

---
 .github/workflows/docker.yml | 11 +++++-
 Dockerfile_gpu               | 26 +++++++++++++
 pyproject_gpu.toml           | 75 ++++++++++++++++++++++++++++++++++++
 3 files changed, 111 insertions(+), 1 deletion(-)
 create mode 100644 Dockerfile_gpu
 create mode 100644 pyproject_gpu.toml

diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 1f0ab779..d12507b8 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -34,7 +34,7 @@ jobs:
           username: ${{ secrets.ACR_USER }}
           password: ${{ secrets.ACR_PASSWORD }}

-      - name: Build and push image
+      - name: Build and push base image
         env:
           IMAGE_TAG: 0.0.1
         run: |
@@ -42,3 +42,12 @@
           docker tag ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG
           docker push ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG
           docker push ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG
+
+      - name: Build and push GPU image
+        env:
+          IMAGE_TAG: 0.0.1_gpu
+        run: |
+          docker build -t ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG -f Dockerfile_gpu .
+          docker tag ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG
+          docker push ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG
+          docker push ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG
diff --git a/Dockerfile_gpu b/Dockerfile_gpu
new file mode 100644
index 00000000..d6313cf0
--- /dev/null
+++ b/Dockerfile_gpu
@@ -0,0 +1,26 @@
+FROM python:3.10-slim AS builder
+
+RUN pip3 install poetry
+
+ENV POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_IN_PROJECT=1 \
+    POETRY_VIRTUALENVS_CREATE=1 \
+    POETRY_CACHE_DIR=/tmp/poetry_cache
+
+WORKDIR /app
+COPY . .
+RUN mv pyproject_gpu.toml pyproject.toml \
+    && rm poetry.lock
+
+RUN poetry install && rm -rf $POETRY_CACHE_DIR
+
+FROM python:3.10-slim AS prod
+ENV VIRTUAL_ENV=/app/.venv \
+    PATH="/app/.venv/bin:$PATH"
+
+RUN apt-get update && apt-get install -y libgl1 libglib2.0-0
+
+WORKDIR /app
+COPY . .
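+# The prod stage reuses the virtualenv built in the builder stage (copied in
+# below), so the runtime image needs neither poetry nor a second install.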
+COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
+ENTRYPOINT ["pai_rag", "run"]
diff --git a/pyproject_gpu.toml b/pyproject_gpu.toml
new file mode 100644
index 00000000..e68206eb
--- /dev/null
+++ b/pyproject_gpu.toml
@@ -0,0 +1,75 @@
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+
+[tool.poetry]
+name = "pai_rag"
+version = "0.1.0"
+description = "Open source RAG framework built on Aliyun PAI"
+authors = []
+readme = "README.md"
+
+[tool.poetry.dependencies]
+python = ">=3.10.0,<3.12"
+fastapi = "^0.110.1"
+uvicorn = "^0.29.0"
+llama-index-core = ">=0.10.29,<=0.10.39"
+llama-index-embeddings-openai = "^0.1.7"
+llama-index-embeddings-azure-openai = "^0.1.7"
+llama-index-embeddings-dashscope = "^0.1.3"
+llama-index-llms-openai = "^0.1.15"
+llama-index-llms-azure-openai = "^0.1.6"
+llama-index-llms-dashscope = "^0.1.2"
+llama-index-readers-database = "^0.1.3"
+llama-index-vector-stores-chroma = "^0.1.6"
+llama-index-vector-stores-faiss = "^0.1.2"
+llama-index-vector-stores-analyticdb = "^0.1.1"
+llama-index-vector-stores-elasticsearch = "^0.2.0"
+llama-index-vector-stores-milvus = "^0.1.10"
+gradio = "3.41.0"
+faiss-cpu = "^1.8.0"
+hologres-vector = "^0.0.9"
+dynaconf = "^3.2.5"
+docx2txt = "^0.8"
+click = "^8.1.7"
+pydantic = "^2.7.0"
+pytest = "^8.1.1"
+llama-index-retrievers-bm25 = "^0.1.3"
+jieba = "^0.42.1"
+llama-index-embeddings-huggingface = "^0.2.0"
+llama-index-postprocessor-flag-embedding-reranker = "^0.1.3"
+flagembedding = "^1.2.10"
+sentencepiece = "^0.2.0"
+oss2 = "^2.18.5"
+asgi-correlation-id = "^4.3.1"
+openinference-instrumentation-llama-index = "1.3.0"
+torch = "2.2.2"
+torchvision = "0.17.2"
+openpyxl = "^3.1.2"
+pdf2image = "^1.17.0"
+llama-index-storage-chat-store-redis = "^0.1.3"
+easyocr = "^1.7.1"
+opencv-python = "^4.9.0.80"
+llama-parse = "0.4.2"
+pypdf2 = "^3.0.1"
+pdfplumber = "^0.11.0"
+pdfminer-six = "^20231228"
+openinference-semantic-conventions = "0.1.6"
+llama-index-tools-google = "^0.1.5"
+llama-index-tools-duckduckgo = "^0.1.1"
+openinference-instrumentation = "^0.1.7"
+llama-index-llms-huggingface = "^0.2.0"
+pytest-asyncio = "^0.23.7"
+pytest-cov = "^5.0.0"
+xlrd = "^2.0.1"
+markdown = "^3.6"
+chardet = "^5.2.0"
+
+[tool.poetry.scripts]
+pai_rag = "pai_rag.main:main"
+load_data = "pai_rag.data.rag_datapipeline:run"
+load_easyocr_model = "pai_rag.utils.download_easyocr_models:download_easyocr_models"
+evaluation = "pai_rag.evaluations.batch_evaluator:run"
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"

From 52cd2386ab4aa1435f9a05ad4891156a9877d788 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 09:59:17 +0800
Subject: [PATCH 02/17] Fix bug

---
 src/pai_rag/app/web/view_model.py  | 2 +-
 src/pai_rag/data/rag_dataloader.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index 2b46c299..65eaf034 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -298,7 +298,7 @@ def to_app_config(self):
             config["postprocessor"]["rerank_model"] = "bge-reranker-large"
         else:
             config["postprocessor"]["rerank_model"] = "no-reranker"
-        config["postprocessor"]["top_n"] = 3
+        config["postprocessor"]["top_n"] = self.similarity_top_k

         config["synthesizer"]["type"] = self.synthesizer_type
         config["synthesizer"]["text_qa_template"] = self.text_qa_template
diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py
index 3dd0cd80..a4d3c5a1 100644
--- a/src/pai_rag/data/rag_dataloader.py
+++ b/src/pai_rag/data/rag_dataloader.py
@@ -60,6 +60,8 @@ def _extract_file_type(self, metadata: Dict[str, Any]):
     async def load(self, file_directory: str, enable_qa_extraction: bool):
         data_reader = self.datareader_factory.get_reader(file_directory)
         docs = data_reader.load_data()
+        logger.info(f"[DataReader] Loaded {len(docs)} docs.")
+
         nodes = []

         doc_cnt_map = {}
@@ -78,6 +80,8 @@ async def load(self, file_directory: str, enable_qa_extraction: bool):
             else:
                 nodes.extend(self.node_parser.get_nodes_from_documents([doc]))

+        logger.info(f"[DataReader] Split into {len(nodes)} nodes.")
+
         # QA metadata extraction
         if enable_qa_extraction:
             qa_nodes = []
@@ -103,6 +107,8 @@ async def load(self, file_directory: str, enable_qa_extraction: bool):
                     node.excluded_llm_metadata_keys.append("question")
             nodes.extend(qa_nodes)

+        logger.info("[DataReader] Start inserting to index.")
+
         self.index.insert_nodes(nodes)
         self.index.storage_context.persist(persist_dir=store_path.persist_path)
         logger.info(f"Inserted {len(nodes)} nodes successfully.")

From 7fc89ac389c7ba8e9d3ed7da96b6d0a7f4765fbc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 10:06:52 +0800
Subject: [PATCH 03/17] Fix gb2312

---
 src/pai_rag/integrations/readers/pai_csv_reader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pai_rag/integrations/readers/pai_csv_reader.py b/src/pai_rag/integrations/readers/pai_csv_reader.py
index 3b673c7f..0b653395 100644
--- a/src/pai_rag/integrations/readers/pai_csv_reader.py
+++ b/src/pai_rag/integrations/readers/pai_csv_reader.py
@@ -141,7 +141,7 @@ def load_data(
             with fs.open(file) as f:
                 encoding = chardet.detect(f.read(100000))["encoding"]
                 f.seek(0)
-                if encoding.upper() in ["GB18030", "GBK"]:
+                if "GB" in encoding.upper():
                     self._pandas_config["encoding"] = "GB18030"
                 try:
                     df = pd.read_csv(f, **self._pandas_config)

From 74ed1d44d6925c12da2a454c627ea89a2778495a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 10:22:37 +0800
Subject: [PATCH 04/17] Update embedding batch size

---
 src/pai_rag/modules/embedding/embedding.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/pai_rag/modules/embedding/embedding.py b/src/pai_rag/modules/embedding/embedding.py
index 24043bf9..dd5dc3bb 100644
--- a/src/pai_rag/modules/embedding/embedding.py
+++ b/src/pai_rag/modules/embedding/embedding.py
@@ -45,9 +45,15 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
         elif source == "huggingface":
             model_dir = config.get("model_dir", DEFAULT_MODEL_DIR)
             model_name = config.get("model_name", DEFAULT_HUGGINGFACE_EMBEDDING_MODEL)
+            embed_batch_size = config.get("embed_batch_size", DEFAULT_EMBED_BATCH_SIZE)
+
             model_path = os.path.join(model_dir, model_name)
-            embed_model = HuggingFaceEmbedding(model_name=model_path)
-            logger.info("Initialized HuggingFace embedding model.")
+            embed_model = HuggingFaceEmbedding(
+                model_name=model_path, embed_batch_size=embed_batch_size
+            )
+            logger.info(
+                f"Initialized HuggingFace embedding model {model_name} with {embed_batch_size} batch size."
+            )

         elif source == "dashscope":
             embed_model = DashScopeEmbedding(

From 7bb44a6a66c9b155f01bbed85726a08fb7554897 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 15:07:53 +0800
Subject: [PATCH 05/17] Set default embedding and llm model

---
 .github/workflows/main.yml            | 3 +++
 src/pai_rag/app/web/view_model.py     | 2 +-
 src/pai_rag/config/settings.toml      | 8 +++++---
 src/pai_rag/modules/llm/llm_module.py | 2 +-
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index cea6d2a7..facfc88b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -50,6 +50,9 @@ jobs:
         env:
           DASHSCOPE_API_KEY: ${{ secrets.TESTDASHSCOPEKEY }}
           IS_PAI_RAG_CI_TEST: true
+          PAIRAG_RAG__embedding__source: "DashScope"
+          PAIRAG_RAG__llm__source: "DashScope"
+          PAIRAG_RAG__llm__name: "qwen-turbo"

       - name: Get Cover
         uses: orgoro/coverage@v3.1
diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py
index 65eaf034..f44e9090 100644
--- a/src/pai_rag/app/web/view_model.py
+++ b/src/pai_rag/app/web/view_model.py
@@ -32,7 +32,7 @@ class ViewModel(BaseModel):
     llm: str = "PaiEas"
     llm_eas_url: str = None
    llm_eas_token: str = None
-    llm_eas_model_name: str = None
+    llm_eas_model_name: str = "PAI-EAS-LLM"
     llm_api_key: str = None
     llm_api_model_name: str = None
     llm_temperature: float = 0.1
diff --git a/src/pai_rag/config/settings.toml b/src/pai_rag/config/settings.toml
index d610039e..8eb3ddf9 100644
--- a/src/pai_rag/config/settings.toml
+++ b/src/pai_rag/config/settings.toml
@@ -20,7 +20,8 @@ persist_path = "localdata/storage"
 type = "SimpleDirectoryReader"

 [rag.embedding]
-source = "DashScope"
+source = "HuggingFace"
+model_name = "bge-small-zh-v1.5"

 [rag.evaluation]
 retrieval = ["mrr", "hit_rate"]
@@ -32,8 +33,9 @@ persist_path = "localdata/storage"
 vector_store.type = "FAISS"

 [rag.llm]
-source = "DashScope"
-name = "qwen-turbo"
+source = "PaiEas"
+endpoint = ""
+token = ""

 [rag.llm_chat_engine]
 type = "SimpleChatEngine"
diff --git a/src/pai_rag/modules/llm/llm_module.py b/src/pai_rag/modules/llm/llm_module.py
index 09482cbf..44b98758 100644
--- a/src/pai_rag/modules/llm/llm_module.py
+++ b/src/pai_rag/modules/llm/llm_module.py
@@ -60,7 +60,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
                 model_name=model_name, temperature=config.get("temperature", 0.1)
             )
         elif source == "paieas":
-            model_name = config["name"]
+            model_name = config.get("name", "PAI-EAS-LLM")
             endpoint = config["endpoint"]
             token = config["token"]
             logger.info(

From 7784254edc1719c677e166d88aed0ab3d32d2e8a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 15:08:39 +0800
Subject: [PATCH 06/17] Update docker tag

---
 .github/workflows/docker.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index d12507b8..d45cb254 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -36,7 +36,7 @@ jobs:

       - name: Build and push base image
         env:
-          IMAGE_TAG: 0.0.1
+          IMAGE_TAG: 0.0.2
         run: |
           docker build -t ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG .
           docker tag ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG
@@ -45,7 +45,7 @@

       - name: Build and push GPU image
         env:
-          IMAGE_TAG: 0.0.1_gpu
+          IMAGE_TAG: 0.0.2_gpu
         run: |
           docker build -t ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG -f Dockerfile_gpu .
           docker tag ${{ env.REGISTRY }}/mybigpai/pairag:$IMAGE_TAG ${{ env.REGISTRY_HZ }}/mybigpai/pairag:$IMAGE_TAG

From ff9ab56a54ee6976e14e9dbd184aaab0d129211e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Fri, 7 Jun 2024 15:36:16 +0800
Subject: [PATCH 07/17] Fix hologres check

---
 src/pai_rag/app/web/tabs/vector_db_panel.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pai_rag/app/web/tabs/vector_db_panel.py b/src/pai_rag/app/web/tabs/vector_db_panel.py
index ea3e175b..fe9da895 100644
--- a/src/pai_rag/app/web/tabs/vector_db_panel.py
+++ b/src/pai_rag/app/web/tabs/vector_db_panel.py
@@ -113,9 +113,9 @@ def create_vector_db_panel(
                         label="Table",
                         elem_id="hologres_table",
                     )
-                    hologres_pre_delete = gr.Dropdown(
-                        ["True", "False"],
-                        label="Pre Delete",
+                    hologres_pre_delete = gr.Checkbox(
+                        label="Yes",
+                        info="Clear hologres table on connection.",
                         elem_id="hologres_pre_delete",
                     )

From 23c06ab5a2ecd27c35a8d2c69035a9c8536b2e8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Wed, 12 Jun 2024 10:04:10 +0800
Subject: [PATCH 08/17] Update registry

---
 src/pai_rag/core/rag_application.py           |  63 ++-
 src/pai_rag/data/rag_datapipeline.py          |  15 +-
 src/pai_rag/modules/__init__.py               |   5 +
 .../modules/base/configurable_module.py       |  26 +-
 src/pai_rag/modules/cache/oss_cache.py        |  20 +
 src/pai_rag/modules/datareader/data_loader.py |  25 +
 src/pai_rag/modules/index/index.py            |  17 +-
 src/pai_rag/modules/index/store.py            |  21 +-
 src/pai_rag/modules/module_registry.py        |  54 +-
 .../retriever/my_elasticsearch_store.py       | 532 ++++++++++++++++++
 src/pai_rag/modules/retriever/retriever.py    |  35 +-
 11 files changed, 714 insertions(+), 99 deletions(-)
 create mode 100644 src/pai_rag/modules/cache/oss_cache.py
 create mode 100644 src/pai_rag/modules/datareader/data_loader.py
 create mode 100644 src/pai_rag/modules/retriever/my_elasticsearch_store.py

diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 8a842ca6..be704a9c 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -1,5 +1,3 @@
-from pai_rag.data.rag_dataloader import RagDataLoader
-from pai_rag.utils.oss_cache import OssCache
 from pai_rag.modules.module_registry import module_registry
 from pai_rag.evaluations.batch_evaluator import BatchEvaluator
 from pai_rag.app.api.models import (
@@ -29,28 +27,7 @@ def __init__(self):

     def initialize(self, config):
         self.config = config
-
         module_registry.init_modules(self.config)
-        self.index = module_registry.get_module("IndexModule")
-        self.llm = module_registry.get_module("LlmModule")
-        self.retriever = module_registry.get_module("RetrieverModule")
-        self.chat_store = module_registry.get_module("ChatStoreModule")
-        self.query_engine = module_registry.get_module("QueryEngineModule")
-        self.chat_engine_factory = module_registry.get_module("ChatEngineFactoryModule")
-        self.llm_chat_engine_factory = module_registry.get_module(
-            "LlmChatEngineFactoryModule"
-        )
-        self.data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
-        self.agent = module_registry.get_module("AgentModule")
-
-        oss_cache = None
-        if config.get("oss_cache", None):
-            oss_cache = OssCache(config.oss_cache)
-        node_parser = module_registry.get_module("NodeParserModule")
-
-        self.data_loader = RagDataLoader(
-            self.data_reader_factory, node_parser, self.index, oss_cache
-        )
         self.logger.info("RagApplication initialized successfully.")

     def reload(self, config):
@@ -59,14 +36,21 @@ def reload(self, config):

     # TODO: support asynchronous ingestion for large batches of uploaded files
     async def load_knowledge(self, file_dir, enable_qa_extraction=False):
-        await self.data_loader.load(file_dir, enable_qa_extraction)
+        data_loader = module_registry.get_module_with_config(
+            "DataLoaderModule", self.config
+        )
+        await data_loader.load(file_dir, enable_qa_extraction)

     async def aquery_retrieval(self, query: RetrievalQuery) -> RetrievalResponse:
         if not query.question:
             return RetrievalResponse(docs=[])

         query_bundle = QueryBundle(query.question)
-        node_results = await self.query_engine.aretrieve(query_bundle)
+
+        query_engine = module_registry.get_module_with_config(
+            "QueryEngineModule", self.config
+        )
+        node_results = await query_engine.aretrieve(query_bundle)

         docs = [
             ContextDoc(
@@ -96,11 +80,18 @@ async def aquery(self, query: RagQuery) -> RagResponse:
                 answer="Empty query. Please input your question.", session_id=session_id
             )

-        query_chat_engine = self.chat_engine_factory.get_chat_engine(
+        chat_engine_factory = module_registry.get_module_with_config(
+            "ChatEngineFactoryModule", self.config
+        )
+        query_chat_engine = chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await query_chat_engine.achat(query.question)
-        self.chat_store.persist()
+
+        chat_store = module_registry.get_module_with_config(
+            "ChatStoreModule", self.config
+        )
+        chat_store.persist()
         return RagResponse(answer=response.response, session_id=session_id)

     async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
@@ -122,11 +113,18 @@ async def aquery_llm(self, query: LlmQuery) -> LlmResponse:
                 answer="Empty query. Please input your question.", session_id=session_id
             )

-        llm_chat_engine = self.llm_chat_engine_factory.get_chat_engine(
+        llm_chat_engine_factory = module_registry.get_module_with_config(
+            "LlmChatEngineFactoryModule", self.config
+        )
+        llm_chat_engine = llm_chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
         response = await llm_chat_engine.achat(query.question)
-        self.chat_store.persist()
+
+        chat_store = module_registry.get_module_with_config(
+            "ChatStoreModule", self.config
+        )
+        chat_store.persist()
         return LlmResponse(answer=response.response, session_id=session_id)

     async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
@@ -143,11 +141,14 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse:
         if not query.question:
             return LlmResponse(answer="Empty query. Please input your question.")

-        response = await self.agent.achat(query.question)
+        agent = module_registry.get_module_with_config("AgentModule", self.config)
+        response = await agent.achat(query.question)
         return LlmResponse(answer=response.response)

     async def batch_evaluate_retrieval_and_response(self, type):
-        batch_eval = BatchEvaluator(self.config, self.retriever, self.query_engine)
+        retriever = module_registry.get_module_with_config("RetrieverModule")
+        query_engine = module_registry.get_module_with_config("QueryEngineModule")
+        batch_eval = BatchEvaluator(self.config, retriever, query_engine)
         df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation(
             type=type, workers=2, save_to_file=True
         )
diff --git a/src/pai_rag/data/rag_datapipeline.py b/src/pai_rag/data/rag_datapipeline.py
index feb0f750..23787e3b 100644
--- a/src/pai_rag/data/rag_datapipeline.py
+++ b/src/pai_rag/data/rag_datapipeline.py
@@ -2,14 +2,12 @@
 import click
 import os
 from pathlib import Path
-from pai_rag.data.rag_dataloader import RagDataLoader
 from pai_rag.core.rag_configuration import RagConfiguration
-from pai_rag.utils.oss_cache import OssCache
 from pai_rag.modules.module_registry import module_registry


 class RagDataPipeline:
-    def __init__(self, data_loader: RagDataLoader):
+    def __init__(self, data_loader):
         self.data_loader = data_loader

     async def ingest_from_folder(self, folder_path: str, enable_qa_extraction: bool):
@@ -23,16 +21,7 @@ def __init_data_pipeline(use_local_qa_model):
     config = RagConfiguration.from_file(config_file).get_value()
     module_registry.init_modules(config)

-    oss_cache = None
-    if config.get("oss_cache", None):
-        oss_cache = OssCache(config.oss_cache)
-    node_parser = module_registry.get_module("NodeParserModule")
-    index = module_registry.get_module("IndexModule")
-    data_reader_factory = module_registry.get_module("DataReaderFactoryModule")
-
-    data_loader = RagDataLoader(
-        data_reader_factory, node_parser, index, oss_cache, use_local_qa_model
-    )
+    data_loader = module_registry.get_module_with_config("DataLoaderModule", config)
     return RagDataPipeline(data_loader)
diff --git a/src/pai_rag/modules/__init__.py b/src/pai_rag/modules/__init__.py
index a23c466b..3491bb86 100644
--- a/src/pai_rag/modules/__init__.py
+++ b/src/pai_rag/modules/__init__.py
@@ -1,5 +1,6 @@
 from pai_rag.modules.embedding.embedding import EmbeddingModule
 from pai_rag.modules.llm.llm_module import LlmModule
+from pai_rag.modules.datareader.data_loader import DataLoaderModule
 from pai_rag.modules.datareader.datareader_factory import DataReaderFactoryModule
 from pai_rag.modules.index.index import IndexModule
 from pai_rag.modules.nodeparser.node_parser import NodeParserModule
@@ -12,10 +13,13 @@
 from pai_rag.modules.chat.chat_store import ChatStoreModule
 from pai_rag.modules.agent.agent import AgentModule
 from pai_rag.modules.tool.tool import ToolModule
+from pai_rag.modules.cache.oss_cache import OssCacheModule
+

 ALL_MODULES = [
     "EmbeddingModule",
     "LlmModule",
+    "DataLoaderModule",
     "DataReaderFactoryModule",
     "IndexModule",
     "NodeParserModule",
@@ -28,6 +32,7 @@
     "LlmChatEngineFactoryModule",
     "AgentModule",
     "ToolModule",
+    "OssCacheModule",
 ]

 __all__ = ALL_MODULES + ["ALL_MODULES"]
diff --git a/src/pai_rag/modules/base/configurable_module.py b/src/pai_rag/modules/base/configurable_module.py
index 78e9b089..de9e3834 100644
--- a/src/pai_rag/modules/base/configurable_module.py
+++ b/src/pai_rag/modules/base/configurable_module.py
@@ -1,19 +1,19 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List, Any
+import logging

 DEFAULT_INSTANCE_KEY = "__DEFAULT_INSTANCE__"

+logger = logging.getLogger(__name__)
+
+
 class ConfigurableModule(ABC):
     """Configurable Module

     Helps to create instances according to configuration.
     """

-    def __init__(self):
-        self.__params_map = {}
-        self.__instance_map = {}
-
     @abstractmethod
     def _create_new_instance(self, new_params: Dict[str, Any]):
         raise NotImplementedError
@@ -24,20 +24,4 @@ def get_dependencies() -> List[str]:
         raise NotImplementedError

     def get_or_create(self, new_params: Dict[str, Any]):
-        return self.get_or_create_by_name(new_params=new_params)
-
-    def get_or_create_by_name(
-        self, new_params: Dict[str, Any], name: str = DEFAULT_INSTANCE_KEY
-    ):
-        # Create new instance when initializing or config changed.
-        if (
-            self.__params_map.get(name, None) is None
-            or self.__params_map[name] != new_params
-        ):
-            print(f"{self.__class__.__name__} param changed, updating")
-            self.__instance_map[name] = self._create_new_instance(new_params)
-            self.__params_map[name] = new_params
-        else:
-            print(f"{self.__class__.__name__} param unchanged, skipping")
-
-        return self.__instance_map[name]
+        return self._create_new_instance(new_params)
diff --git a/src/pai_rag/modules/cache/oss_cache.py b/src/pai_rag/modules/cache/oss_cache.py
new file mode 100644
index 00000000..d4a543b9
--- /dev/null
+++ b/src/pai_rag/modules/cache/oss_cache.py
@@ -0,0 +1,20 @@
+from typing import Any, Dict, List
+from pai_rag.utils.oss_cache import OssCache
+from pai_rag.modules.base.configurable_module import ConfigurableModule
+from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class OssCacheModule(ConfigurableModule):
+    @staticmethod
+    def get_dependencies() -> List[str]:
+        return []
+
+    def _create_new_instance(self, new_params: Dict[str, Any]):
+        cache_config = new_params[MODULE_PARAM_CONFIG]
+        if cache_config:
+            return OssCache(cache_config)
+        else:
+            return None
diff --git a/src/pai_rag/modules/datareader/data_loader.py b/src/pai_rag/modules/datareader/data_loader.py
new file mode 100644
index 00000000..5e5e0383
--- /dev/null
+++ b/src/pai_rag/modules/datareader/data_loader.py
@@ -0,0 +1,25 @@
+from typing import Any, Dict, List
+from pai_rag.modules.base.configurable_module import ConfigurableModule
+from pai_rag.data.rag_dataloader import RagDataLoader
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class DataLoaderModule(ConfigurableModule):
+    @staticmethod
+    def get_dependencies() -> List[str]:
+        return [
+            "OssCacheModule",
+            "DataReaderFactoryModule",
+            "NodeParserModule",
+            "IndexModule",
+        ]
+
+    def _create_new_instance(self, new_params: Dict[str, Any]):
+        oss_cache = new_params["OssCacheModule"]
+        data_reader_factory = new_params["DataReaderFactoryModule"]
+        node_parser = new_params["NodeParserModule"]
+        index = new_params["IndexModule"]
+
+        return RagDataLoader(data_reader_factory, node_parser, index, oss_cache)
diff --git a/src/pai_rag/modules/index/index.py b/src/pai_rag/modules/index/index.py
index 248fbc96..4221193b 100644
--- a/src/pai_rag/modules/index/index.py
+++ b/src/pai_rag/modules/index/index.py
@@ -9,6 +9,7 @@
 from pai_rag.modules.base.configurable_module import ConfigurableModule
 from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
 from pai_rag.modules.index.store import RagStore
+from llama_index.vector_stores.faiss import FaissVectorStore
 from pai_rag.utils.store_utils import get_store_persist_directory_name, store_path

 logging.basicConfig(
@@ -18,7 +19,6 @@
 )

 DEFAULT_PERSIST_DIR = "./storage"
-INDEX_STATE_FILE = "index.state.json"


 class IndexModule(ConfigurableModule):
@@ -60,16 +60,19 @@ def create_indices(self):
         logging.info("Empty index, need to create indices.")

         vector_index = VectorStoreIndex(
-            nodes=[],
-            storage_context=self.storage_context,
-            embed_model=self.embed_model,
-            store_nodes_override=True,
+            nodes=[], storage_context=self.storage_context, embed_model=self.embed_model
         )
         logging.info("Created vector_index.")
         return vector_index

     def load_indices(self):
-        vector_index = load_index_from_storage(storage_context=self.storage_context)
-
+        if isinstance(self.storage_context.vector_store, FaissVectorStore):
+            vector_index = load_index_from_storage(storage_context=self.storage_context)
+        else:
+            vector_index = VectorStoreIndex(
+                nodes=[],
+                storage_context=self.storage_context,
+                embed_model=self.embed_model,
+            )
         return vector_index
diff --git a/src/pai_rag/modules/index/store.py b/src/pai_rag/modules/index/store.py
index 1a290cdd..09702055 100644
--- a/src/pai_rag/modules/index/store.py
+++ b/src/pai_rag/modules/index/store.py
@@ -6,7 +6,6 @@
 from llama_index.vector_stores.analyticdb import AnalyticDBVectorStore
 from llama_index.vector_stores.faiss import FaissVectorStore
 from llama_index.vector_stores.chroma import ChromaVectorStore
-from llama_index.vector_stores.elasticsearch import ElasticsearchStore
 from llama_index.vector_stores.milvus import MilvusVectorStore
 from elasticsearch.helpers.vectorstore import AsyncDenseVectorStrategy

@@ -16,6 +15,8 @@
 from llama_index.core import StorageContext
 import logging

+from pai_rag.modules.retriever.my_elasticsearch_store import MyElasticsearchStore
+
 DEFAULT_CHROMA_COLLECTION_NAME = "pairag"

 logger = logging.getLogger(__name__)
@@ -36,14 +37,19 @@ def _get_or_create_storage_context(self):
         self.vector_store = None
         self.doc_store = None
         self.index_store = None
+        persist_dir = None

         vector_store_type = (
-            self.store_config["vector_store"].get("type", "chroma").lower()
+            self.store_config["vector_store"].get("type", "faiss").lower()
         )
+
         if vector_store_type == "chroma":
             self.vector_store = self._get_or_create_chroma()
             logger.info("initialized Chroma vector store.")
         elif vector_store_type == "faiss":
+            self.doc_store = self._get_or_create_simple_doc_store()
+            self.index_store = self._get_or_create_simple_index_store()
+            persist_dir = self.persist_dir
             self.vector_store = self._get_or_create_faiss()
             logger.info("initialized FAISS vector store.")
         elif vector_store_type == "hologres":
@@ -60,14 +66,11 @@ def _get_or_create_storage_context(self):
         else:
             raise ValueError(f"Unknown vector_store type '{vector_store_type}'.")

-        self.doc_store = self._get_or_create_simple_doc_store()
-        self.index_store = self._get_or_create_simple_index_store()
-
         storage_context = StorageContext.from_defaults(
             docstore=self.doc_store,
             index_store=self.index_store,
             vector_store=self.vector_store,
-            persist_dir=self.persist_dir,
+            persist_dir=persist_dir,
         )
         return storage_context

@@ -124,12 +127,14 @@ def _get_or_create_adb(self):

     def _get_or_create_es(self):
         es_config = self.store_config["vector_store"]

-        return ElasticsearchStore(
+        return MyElasticsearchStore(
             index_name=es_config["es_index"],
             es_url=es_config["es_url"],
             es_user=es_config["es_user"],
             es_password=es_config["es_password"],
-            retrieval_strategy=AsyncDenseVectorStrategy(hybrid=True),
+            retrieval_strategy=AsyncDenseVectorStrategy(
+                hybrid=True, rrf={"window_size": 50}
+            ),
         )

     def _get_or_create_milvus(self):
diff --git a/src/pai_rag/modules/module_registry.py b/src/pai_rag/modules/module_registry.py
index 12c65ce3..cc2dca46 100644
--- a/src/pai_rag/modules/module_registry.py
+++ b/src/pai_rag/modules/module_registry.py
@@ -1,5 +1,8 @@
+import hashlib
+from typing import Dict, Any
 from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
 import pai_rag.modules as modules
+import logging

 MODULE_CONFIG_KEY_MAP = {
     "IndexModule": "index",
@@ -16,17 +19,25 @@
     "DataReaderFactoryModule": "data_reader",
     "AgentModule": "agent",
     "ToolModule": "tool",
+    "DataLoaderModule": "data_loader",
+    "OssCacheModule": "cache",
 }


+logger = logging.getLogger(__name__)
+
+
 class ModuleRegistry:
     def __init__(self):
-        self._mod_instance_map = {}
         self._mod_cls_map = {}
         self._mod_deps_map = {}
         self._mod_deps_map_inverted = {}

+        self._mod_instance_map = {}
+
         for m_name in modules.ALL_MODULES:
+            self._mod_instance_map[m_name] = {}
+
             m_cls = getattr(modules, m_name)
             self._mod_cls_map[m_name] = m_cls()
@@ -38,12 +49,16 @@ def __init__(self):
                     self._mod_deps_map_inverted[dep] = []
                 self._mod_deps_map_inverted[dep].append(m_name)

-    def get_module(self, module_key: str):
-        return self._mod_instance_map[module_key]
+    def _get_param_hash(self, params: Dict[str, Any]):
+        repr_str = repr(sorted(params.items())).encode("utf-8")
+        return hashlib.sha256(repr_str).hexdigest()
+
+    def get_module_with_config(self, module_key, config):
+        return self._create_mod_lazily(module_key, config)

     def init_modules(self, config):
+        mod_cache = {}
         mod_stack = []
-        mods_inited = []
         mod_ref_count = {}
         for mod, deps in self._mod_deps_map.items():
             ref_count = len(deps)
@@ -53,9 +68,8 @@ def init_modules(self, config):

         while mod_stack:
             mod = mod_stack.pop()
-            mod_obj = self._init_mod(mod, config)
-            mods_inited.append(mod)
-            self._mod_instance_map[mod] = mod_obj
+            mod_obj = self._create_mod_lazily(mod, config, mod_cache)
+            mod_cache[mod] = mod_obj

             # update the ref count of the modules that depend on this one
             ref_mods = self._mod_deps_map_inverted.get(mod, [])
@@ -64,23 +78,35 @@ def init_modules(self, config):
                 if mod_ref_count[ref_mod] == 0:
                     mod_stack.append(ref_mod)

-        if len(mods_inited) != len(modules.ALL_MODULES):
+        if len(mod_cache) != len(modules.ALL_MODULES):
             # dependency circular error!
             raise ValueError(
-                f"Circular dependency detected. Please check module dependency configuration. Module initialized: {mods_inited}. Module ref count: {mod_ref_count}"
+                f"Circular dependency detected. Please check module dependency configuration. Module initialized: {mod_cache}. Module ref count: {mod_ref_count}"
             )
-        print(f"RAG modules init successfully. {mods_inited}")
+        logger.info(f"RAG modules init successfully. {mod_cache.keys()}")
         return

-    def _init_mod(self, mod_name, config):
+    def _create_mod_lazily(self, mod_name, config, mod_cache=None):
+        if mod_cache and mod_name in mod_cache:
+            return mod_cache[mod_name]
+
+        logger.info(f"Get module {mod_name}.")
+
         mod_config_key = MODULE_CONFIG_KEY_MAP[mod_name]
         mod_deps = self._mod_deps_map[mod_name]
         mod_cls = self._mod_cls_map[mod_name]

-        params = {MODULE_PARAM_CONFIG: config[mod_config_key]}
+        params = {MODULE_PARAM_CONFIG: config.get(mod_config_key, None)}
         for dep in mod_deps:
-            params[dep] = self.get_module(dep)
-        return mod_cls.get_or_create(params)
+            params[dep] = self._create_mod_lazily(dep, config, mod_cache)
+
+        instance_key = self._get_param_hash(params)
+        if instance_key not in self._mod_instance_map[mod_name]:
+            logger.info(f"Creating new instance for module {mod_name} {instance_key}.")
+            self._mod_instance_map[mod_name][instance_key] = mod_cls.get_or_create(
+                params
+            )
+        return self._mod_instance_map[mod_name][instance_key]


 module_registry = ModuleRegistry()
diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
new file mode 100644
index 00000000..cb68d4ec
--- /dev/null
+++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
@@ -0,0 +1,532 @@
+"""Elasticsearch vector store."""
+
+import asyncio
+from logging import getLogger
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
+
+import nest_asyncio
+import numpy as np
+from llama_index.core.bridge.pydantic import PrivateAttr
+from llama_index.core.schema import BaseNode, MetadataMode, TextNode
+from llama_index.core.vector_stores.types import (
+    BasePydanticVectorStore,
+    MetadataFilters,
+    VectorStoreQuery,
+    VectorStoreQueryMode,
+    VectorStoreQueryResult,
+)
+from llama_index.core.vector_stores.utils import (
+    metadata_dict_to_node,
+    node_to_metadata_dict,
+)
+from elasticsearch.helpers.vectorstore import AsyncVectorStore
+from elasticsearch.helpers.vectorstore import (
+    AsyncBM25Strategy,
+    AsyncSparseVectorStrategy,
+    AsyncDenseVectorStrategy,
+    AsyncRetrievalStrategy,
+    DistanceMetric,
+)
+
+from llama_index.vector_stores.elasticsearch.utils import (
+    get_elasticsearch_client,
+    get_user_agent,
+)
+
+logger = getLogger(__name__)
+
+DISTANCE_STRATEGIES = Literal[
+    "COSINE",
+    "DOT_PRODUCT",
+    "EUCLIDEAN_DISTANCE",
+]
+
+
+def _to_elasticsearch_filter(standard_filters: MetadataFilters) -> Dict[str, Any]:
+    """
+    Convert standard filters to Elasticsearch filter.
+
+    Args:
+        standard_filters: Standard Llama-index filters.
+
+    Returns:
+        Elasticsearch filter.
+    """
+    if len(standard_filters.legacy_filters()) == 1:
+        filter = standard_filters.legacy_filters()[0]
+        return {
+            "term": {
+                f"metadata.{filter.key}.keyword": {
+                    "value": filter.value,
+                }
+            }
+        }
+    else:
+        operands = []
+        for filter in standard_filters.legacy_filters():
+            operands.append(
+                {
+                    "term": {
+                        f"metadata.{filter.key}.keyword": {
+                            "value": filter.value,
+                        }
+                    }
+                }
+            )
+        return {"bool": {"must": operands}}
+
+
+def _to_llama_similarities(scores: List[float]) -> List[float]:
+    if scores is None or len(scores) == 0:
+        return []
+
+    scores_to_norm: np.ndarray = np.array(scores)
+    return np.exp(scores_to_norm - np.max(scores_to_norm)).tolist()
+
+
+def _mode_must_match_retrieval_strategy(
+    mode: VectorStoreQueryMode, retrieval_strategy: AsyncRetrievalStrategy
+) -> None:
+    """
+    Different retrieval strategies require different ways of indexing that must be known at the
+    time of adding data. The query mode is known at query time. This function checks if the
+    retrieval strategy (and way of indexing) is compatible with the query mode and raises an
+    exception in the case of a mismatch.
+    """
+    if mode == VectorStoreQueryMode.DEFAULT:
+        # it's fine to not specify an explicit other mode
+        return
+
+    mode_retrieval_dict = {
+        VectorStoreQueryMode.SPARSE: AsyncSparseVectorStrategy,
+        VectorStoreQueryMode.TEXT_SEARCH: AsyncBM25Strategy,
+        VectorStoreQueryMode.HYBRID: AsyncDenseVectorStrategy,
+    }
+
+    required_strategy = mode_retrieval_dict.get(mode)
+    if not required_strategy:
+        raise NotImplementedError(f"query mode {mode} currently not supported")
+
+    if not isinstance(retrieval_strategy, required_strategy):
+        raise ValueError(
+            f"query mode {mode} incompatible with retrieval strategy {type(retrieval_strategy)}, "
+            f"expected {required_strategy}"
+        )
+
+    if mode == VectorStoreQueryMode.HYBRID and not retrieval_strategy.hybrid:
+        raise ValueError("to enable hybrid mode, it must be set in retrieval strategy")
+
+
+class MyElasticsearchStore(BasePydanticVectorStore):
+    """
+    Elasticsearch vector store.
+
+    Args:
+        index_name: Name of the Elasticsearch index.
+        es_client: Optional. Pre-existing AsyncElasticsearch client.
+        es_url: Optional. Elasticsearch URL.
+        es_cloud_id: Optional. Elasticsearch cloud ID.
+        es_api_key: Optional. Elasticsearch API key.
+        es_user: Optional. Elasticsearch username.
+        es_password: Optional. Elasticsearch password.
+        text_field: Optional. Name of the Elasticsearch field that stores the text.
+        vector_field: Optional. Name of the Elasticsearch field that stores the
+            embedding.
+        batch_size: Optional. Batch size for bulk indexing. Defaults to 200.
+        distance_strategy: Optional. Distance strategy to use for similarity search.
+            Defaults to "COSINE".
+        retrieval_strategy: Retrieval strategy to use. AsyncBM25Strategy /
+            AsyncSparseVectorStrategy / AsyncDenseVectorStrategy / AsyncRetrievalStrategy.
+            Defaults to AsyncDenseVectorStrategy.
+
+    Raises:
+        ConnectionError: If AsyncElasticsearch client cannot connect to Elasticsearch.
+        ValueError: If neither es_client nor es_url nor es_cloud_id is provided.
+
+    Examples:
+        `pip install llama-index-vector-stores-elasticsearch`
+
+        ```python
+        from llama_index.vector_stores import ElasticsearchStore
+
+        # Additional setup for ElasticsearchStore class
+        index_name = "my_index"
+        es_url = "http://localhost:9200"
+        es_cloud_id = ""  # Found within the deployment page
+        es_user = "elastic"
+        es_password = ""  # Provided when creating deployment or can be reset
+        es_api_key = ""  # Create an API key within Kibana (Security -> API Keys)
+
+        # Connecting to ElasticsearchStore locally
+        es_local = ElasticsearchStore(
+            index_name=index_name,
+            es_url=es_url,
+        )
+
+        # Connecting to Elastic Cloud with username and password
+        es_cloud_user_pass = ElasticsearchStore(
+            index_name=index_name,
+            es_cloud_id=es_cloud_id,
+            es_user=es_user,
+            es_password=es_password,
+        )
+
+        # Connecting to Elastic Cloud with API Key
+        es_cloud_api_key = ElasticsearchStore(
+            index_name=index_name,
+            es_cloud_id=es_cloud_id,
+            es_api_key=es_api_key,
+        )
+        ```
+
+    """
+
+    class Config:
+        # allow pydantic to tolerate its inability to validate AsyncRetrievalStrategy
+        arbitrary_types_allowed = True
+
+    stores_text: bool = True
+    index_name: str
+    es_client: Optional[Any]
+    es_url: Optional[str]
+    es_cloud_id: Optional[str]
+    es_api_key: Optional[str]
+    es_user: Optional[str]
+    es_password: Optional[str]
+    text_field: str = "content"
+    vector_field: str = "embedding"
+    batch_size: int = 200
+    distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE"
+    retrieval_strategy: AsyncRetrievalStrategy
+
+    _store = PrivateAttr()
+
+    def __init__(
+        self,
+        index_name: str,
+        es_client: Optional[Any] = None,
+        es_url: Optional[str] = None,
+        es_cloud_id: Optional[str] = None,
+        es_api_key: Optional[str] = None,
+        es_user: Optional[str] = None,
+        es_password: Optional[str] = None,
+        text_field: str = "content",
+        vector_field: str = "embedding",
+        batch_size: int = 200,
+        distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
+        retrieval_strategy: Optional[AsyncRetrievalStrategy] = None,
+    ) -> None:
+        nest_asyncio.apply()
+
+        if not es_client:
+            es_client = get_elasticsearch_client(
+                url=es_url,
+                cloud_id=es_cloud_id,
+                api_key=es_api_key,
+                username=es_user,
+                password=es_password,
+            )
+
+        if retrieval_strategy is None:
+            retrieval_strategy = AsyncDenseVectorStrategy(
+                distance=DistanceMetric[distance_strategy]
+            )
+
+        metadata_mappings = {
+            "document_id": {"type": "keyword"},
+            "doc_id": {"type": "keyword"},
+            "ref_doc_id": {"type": "keyword"},
+        }
+
+        self._store = AsyncVectorStore(
+            user_agent=get_user_agent(),
+            client=es_client,
+            index=index_name,
+            retrieval_strategy=retrieval_strategy,
+            text_field=text_field,
+            vector_field=vector_field,
+            metadata_mappings=metadata_mappings,
+        )
+
+        super().__init__(
+            index_name=index_name,
+            es_client=es_client,
+            es_url=es_url,
+            es_cloud_id=es_cloud_id,
+            es_api_key=es_api_key,
+            es_user=es_user,
+            es_password=es_password,
+            text_field=text_field,
+            vector_field=vector_field,
+            batch_size=batch_size,
+            distance_strategy=distance_strategy,
+            retrieval_strategy=retrieval_strategy,
+        )
+
+    @property
+    def client(self) -> Any:
+        """Get async elasticsearch client."""
+        return self._store.client
+
+    def close(self) -> None:
+        return asyncio.get_event_loop().run_until_complete(self._store.close())
+
+    def add(
+        self,
+        nodes: List[BaseNode],
+        *,
+        create_index_if_not_exists: bool = True,
+        **add_kwargs: Any,
+    ) -> List[str]:
+        """
+        Add nodes to Elasticsearch index.
+
+        Args:
+            nodes: List of nodes with embeddings.
+            create_index_if_not_exists: Optional. Whether to create
+                the Elasticsearch index if it
+                doesn't already exist.
+                Defaults to True.
+
+        Returns:
+            List of node IDs that were added to the index.
+
+        Raises:
+            ImportError: If elasticsearch['async'] python package is not installed.
+            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
+        """
+        return asyncio.get_event_loop().run_until_complete(
+            self.async_add(nodes, create_index_if_not_exists=create_index_if_not_exists)
+        )
+
+    async def async_add(
+        self,
+        nodes: List[BaseNode],
+        *,
+        create_index_if_not_exists: bool = True,
+        **add_kwargs: Any,
+    ) -> List[str]:
+        """
+        Asynchronous method to add nodes to Elasticsearch index.
+
+        Args:
+            nodes: List of nodes with embeddings.
+            create_index_if_not_exists: Optional. Whether to create
+                the AsyncElasticsearch index if it
+                doesn't already exist.
+                Defaults to True.
+
+        Returns:
+            List of node IDs that were added to the index.
+
+        Raises:
+            ImportError: If elasticsearch python package is not installed.
+            BulkIndexError: If AsyncElasticsearch async_bulk indexing fails.
+        """
+        if len(nodes) == 0:
+            return []
+
+        embeddings: List[List[float]] = []
+        texts: List[str] = []
+        metadatas: List[dict] = []
+        ids: List[str] = []
+        for node in nodes:
+            ids.append(node.node_id)
+            embeddings.append(node.get_embedding())
+            texts.append(node.get_content(metadata_mode=MetadataMode.NONE))
+            metadatas.append(node_to_metadata_dict(node, remove_text=True))
+
+        if not self._store.num_dimensions:
+            self._store.num_dimensions = len(embeddings[0])
+
+        return await self._store.add_texts(
+            texts=texts,
+            metadatas=metadatas,
+            vectors=embeddings,
+            ids=ids,
+            create_index_if_not_exists=create_index_if_not_exists,
+            bulk_kwargs=add_kwargs,
+        )
+
+    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
+        """
+        Delete node from Elasticsearch index.
+
+        Args:
+            ref_doc_id: ID of the node to delete.
+            delete_kwargs: Optional. Additional arguments to
+                pass to Elasticsearch delete_by_query.
+
+        Raises:
+            Exception: If Elasticsearch delete_by_query fails.
+        """
+        return asyncio.get_event_loop().run_until_complete(
+            self.adelete(ref_doc_id, **delete_kwargs)
+        )
+
+    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
+        """
+        Async delete node from Elasticsearch index.
+
+        Args:
+            ref_doc_id: ID of the node to delete.
+            delete_kwargs: Optional. Additional arguments to
+                pass to AsyncElasticsearch delete_by_query.
+
+        Raises:
+            Exception: If AsyncElasticsearch delete_by_query fails.
+        """
+        await self._store.delete(
+            query={"term": {"metadata.ref_doc_id": ref_doc_id}}, **delete_kwargs
+        )
+
+    def query(
+        self,
+        query: VectorStoreQuery,
+        custom_query: Optional[
+            Callable[[Dict, Union[VectorStoreQuery, None]], Dict]
+        ] = None,
+        es_filter: Optional[List[Dict]] = None,
+        **kwargs: Any,
+    ) -> VectorStoreQueryResult:
+        """
+        Query index for top k most similar nodes.
+
+        Args:
+            query_embedding (List[float]): query embedding
+            custom_query: Optional. custom query function that takes in the es query
+                body and returns a modified query body.
+                This can be used to add additional query
+                parameters to the Elasticsearch query.
+            es_filter: Optional. Elasticsearch filter to apply to the
+                query. If filter is provided in the query,
+                this filter will be ignored.
+
+        Returns:
+            VectorStoreQueryResult: Result of the query.
+
+        Raises:
+            Exception: If Elasticsearch query fails.
+ + """ + return asyncio.get_event_loop().run_until_complete( + self.aquery(query, custom_query, es_filter, **kwargs) + ) + + async def aquery( + self, + query: VectorStoreQuery, + custom_query: Optional[ + Callable[[Dict, Union[VectorStoreQuery, None]], Dict] + ] = None, + es_filter: Optional[List[Dict]] = None, + **kwargs: Any, + ) -> VectorStoreQueryResult: + """ + Asynchronous query index for top k most similar nodes. + + Args: + query_embedding (VectorStoreQuery): query embedding + custom_query: Optional. custom query function that takes in the es query + body and returns a modified query body. + This can be used to add additional query + parameters to the AsyncElasticsearch query. + es_filter: Optional. AsyncElasticsearch filter to apply to the + query. If filter is provided in the query, + this filter will be ignored. + + Returns: + VectorStoreQueryResult: Result of the query. + + Raises: + Exception: If AsyncElasticsearch query fails. + + """ + print("================: ", query.mode) + if query.mode == VectorStoreQueryMode.HYBRID: + retrieval_strategy = AsyncDenseVectorStrategy( + hybrid=True, rrf={"window_size": 50} + ) + elif query.mode == VectorStoreQueryMode.TEXT_SEARCH: + retrieval_strategy = AsyncBM25Strategy() + else: + retrieval_strategy = AsyncDenseVectorStrategy() + + metadata_mappings = { + "document_id": {"type": "keyword"}, + "doc_id": {"type": "keyword"}, + "ref_doc_id": {"type": "keyword"}, + } + self._store = AsyncVectorStore( + user_agent=get_user_agent(), + client=self.es_client, + index=self.index_name, + retrieval_strategy=retrieval_strategy, + text_field=self.text_field, + vector_field=self.vector_field, + metadata_mappings=metadata_mappings, + ) + + if query.filters is not None and len(query.filters.legacy_filters()) > 0: + filter = [_to_elasticsearch_filter(query.filters)] + else: + filter = es_filter or [] + + hits = await self._store.search( + query=query.query_str, + query_vector=query.query_embedding, + k=query.similarity_top_k, + num_candidates=query.similarity_top_k * 10, + filter=filter, + custom_query=custom_query, + ) + + top_k_nodes = [] + top_k_ids = [] + top_k_scores = [] + for hit in hits: + source = hit["_source"] + metadata = source.get("metadata", None) + text = source.get(self.text_field, None) + node_id = hit["_id"] + + try: + node = metadata_dict_to_node(metadata) + node.text = text + except Exception: + # Legacy support for old metadata format + logger.warning( + f"Could not parse metadata from hit {hit['_source']['metadata']}" + ) + node_info = source.get("node_info") + relationships = source.get("relationships", {}) + start_char_idx = None + end_char_idx = None + if isinstance(node_info, dict): + start_char_idx = node_info.get("start", None) + end_char_idx = node_info.get("end", None) + + node = TextNode( + text=text, + metadata=metadata, + id_=node_id, + start_char_idx=start_char_idx, + end_char_idx=end_char_idx, + relationships=relationships, + ) + top_k_nodes.append(node) + top_k_ids.append(node_id) + top_k_scores.append(hit.get("_rank", hit["_score"])) + + if ( + isinstance(self.retrieval_strategy, AsyncDenseVectorStrategy) + and self.retrieval_strategy.hybrid + ): + total_rank = sum(top_k_scores) + top_k_scores = [total_rank - rank / total_rank for rank in top_k_scores] + + return VectorStoreQueryResult( + nodes=top_k_nodes, + ids=top_k_ids, + similarities=_to_llama_similarities(top_k_scores), + ) diff --git a/src/pai_rag/modules/retriever/retriever.py b/src/pai_rag/modules/retriever/retriever.py index bc7cbca3..77f3c153 100644 --- 
a/src/pai_rag/modules/retriever/retriever.py +++ b/src/pai_rag/modules/retriever/retriever.py @@ -10,12 +10,14 @@ from llama_index.core.tools import RetrieverTool from llama_index.core.selectors import LLMSingleSelector from llama_index.core.retrievers import RouterRetriever +from llama_index.core.vector_stores.types import VectorStoreQueryMode # from llama_index.retrievers.bm25 import BM25Retriever from pai_rag.integrations.retrievers.bm25 import BM25Retriever from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG from pai_rag.utils.prompt_template import QUERY_GEN_PROMPT +from pai_rag.modules.retriever.my_elasticsearch_store import MyElasticsearchStore from pai_rag.modules.retriever.my_vector_index_retriever import MyVectorIndexRetriever logger = logging.getLogger(__name__) @@ -23,8 +25,15 @@ stopword_list = stopwords.words("chinese") + stopwords.words("english") +## PUT in utils file and add stopword in TRIE structure. def jieba_tokenize(text: str) -> List[str]: - return [w for w in jieba.lcut(text) if w not in stopword_list] + tokens = [] + for w in jieba.lcut(text): + token = w.lower() + if token not in stopword_list: + tokens.append(token) + + return tokens class RetrieverModule(ConfigurableModule): @@ -37,11 +46,27 @@ def _create_new_instance(self, new_params: Dict[str, Any]): vector_index = new_params["IndexModule"] similarity_top_k = config.get("similarity_top_k", 5) - # vector + + retrieval_mode = config.get("retrieval_mode", "hybrid").lower() + + # Special handle elastic search + if isinstance(vector_index.storage_context.vector_store, MyElasticsearchStore): + if retrieval_mode == "embedding": + query_mode = VectorStoreQueryMode.DEFAULT + elif retrieval_mode == "keyword": + query_mode = VectorStoreQueryMode.TEXT_SEARCH + else: + query_mode = VectorStoreQueryMode.HYBRID + + return MyVectorIndexRetriever( + index=vector_index, + similarity_top_k=similarity_top_k, + vector_store_query_mode=query_mode, + ) + vector_retriever = MyVectorIndexRetriever( index=vector_index, similarity_top_k=similarity_top_k ) - # keyword bm25_retriever = BM25Retriever.from_defaults( index=vector_index, @@ -49,11 +74,11 @@ def _create_new_instance(self, new_params: Dict[str, Any]): tokenizer=jieba_tokenize, ) - if config["retrieval_mode"] == "embedding": + if retrieval_mode == "embedding": logger.info(f"MyVectorIndexRetriever used with top_k {similarity_top_k}.") return vector_retriever - elif config["retrieval_mode"] == "keyword": + elif retrieval_mode == "keyword": logger.info(f"BM25Retriever used with top_k {similarity_top_k}.") return bm25_retriever From ec97b641bfab85d812b8f69ad603257745858912 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Wed, 12 Jun 2024 10:40:03 +0800 Subject: [PATCH 09/17] Fix bug --- src/pai_rag/core/rag_application.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py index be704a9c..2fc2f622 100644 --- a/src/pai_rag/core/rag_application.py +++ b/src/pai_rag/core/rag_application.py @@ -146,8 +146,12 @@ async def aquery_agent(self, query: LlmQuery) -> LlmResponse: return LlmResponse(answer=response.response) async def batch_evaluate_retrieval_and_response(self, type): - retriever = module_registry.get_module_with_config("RetrieverModule") - query_engine = module_registry.get_module_with_config("QueryEngineModule") + retriever = module_registry.get_module_with_config( + 
"RetrieverModule", self.config + ) + query_engine = module_registry.get_module_with_config( + "QueryEngineModule", self.config + ) batch_eval = BatchEvaluator(self.config, retriever, query_engine) df, eval_res_avg = await batch_eval.batch_retrieval_response_aevaluation( type=type, workers=2, save_to_file=True From 5fb8e64b3c0eefb758f0afa673da2e227d135451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Wed, 12 Jun 2024 11:37:48 +0800 Subject: [PATCH 10/17] Fix tests --- src/pai_rag/evaluations/batch_evaluator.py | 4 ++-- .../dataset_generation/generate_dataset.py | 23 ++++++++++++++++--- .../retriever/my_elasticsearch_store.py | 1 - tests/core/test_rag_application.py | 3 --- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/pai_rag/evaluations/batch_evaluator.py b/src/pai_rag/evaluations/batch_evaluator.py index 47684e25..abd5f142 100644 --- a/src/pai_rag/evaluations/batch_evaluator.py +++ b/src/pai_rag/evaluations/batch_evaluator.py @@ -189,8 +189,8 @@ def __init_evaluator_pipeline(): config = RagConfiguration.from_file(config_file).get_value() module_registry.init_modules(config) - retriever = module_registry.get_module("RetrieverModule") - query_engine = module_registry.get_module("QueryEngineModule") + retriever = module_registry.get_module_with_config("RetrieverModule", config) + query_engine = module_registry.get_module_with_config("QueryEngineModule", config) return BatchEvaluator(config, retriever, query_engine) diff --git a/src/pai_rag/evaluations/dataset_generation/generate_dataset.py b/src/pai_rag/evaluations/dataset_generation/generate_dataset.py index 3bfcddc0..02b43e60 100644 --- a/src/pai_rag/evaluations/dataset_generation/generate_dataset.py +++ b/src/pai_rag/evaluations/dataset_generation/generate_dataset.py @@ -1,5 +1,6 @@ import logging import os +from pathlib import Path from pai_rag.core.rag_configuration import RagConfiguration from pai_rag.modules.module_registry import module_registry from llama_index.core.prompts.prompt_type import PromptType @@ -16,8 +17,13 @@ DEFAULT_TEXT_QA_PROMPT_TMPL, DEFAULT_QUESTION_GENERATION_QUERY, ) + import json +_BASE_DIR = Path(__file__).parent.parent.parent +DEFAULT_EVAL_CONFIG_FILE = os.path.join(_BASE_DIR, "config/settings.toml") +DEFAULT_EVAL_DATA_FOLDER = "tests/testdata/paul_graham" + class GenerateDatasetPipeline(ModifiedRagDatasetGenerator): def __init__( @@ -29,11 +35,22 @@ def __init__( show_progress: Optional[bool] = True, ) -> None: self.name = "GenerateDatasetPipeline" - self.nodes = list( - module_registry.get_module("IndexModule").docstore.docs.values() + self.config = RagConfiguration.from_file(DEFAULT_EVAL_CONFIG_FILE).get_value() + + # load nodes + module_registry.init_modules(self.config) + datareader_factory = module_registry.get_module_with_config( + "DataReaderFactoryModule", self.config ) + self.node_parser = module_registry.get_module_with_config( + "NodeParserModule", self.config + ) + reader = datareader_factory.get_reader(DEFAULT_EVAL_DATA_FOLDER) + docs = reader.load_data() + self.nodes = self.node_parser.get_nodes_from_documents(docs) + self.num_questions_per_chunk = num_questions_per_chunk - self.llm = module_registry.get_module("LlmModule") + self.llm = module_registry.get_module_with_config("LlmModule", self.config) self.text_question_template = PromptTemplate(text_question_template_str) self.text_qa_template = PromptTemplate( text_qa_template_str, prompt_type=PromptType.QUESTION_ANSWER diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py 
b/src/pai_rag/modules/retriever/my_elasticsearch_store.py index cb68d4ec..d076dcf4 100644 --- a/src/pai_rag/modules/retriever/my_elasticsearch_store.py +++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py @@ -442,7 +442,6 @@ async def aquery( Exception: If AsyncElasticsearch query fails. """ - print("================: ", query.mode) if query.mode == VectorStoreQueryMode.HYBRID: retrieval_strategy = AsyncDenseVectorStrategy( hybrid=True, rrf={"window_size": 50} diff --git a/tests/core/test_rag_application.py b/tests/core/test_rag_application.py index 0e42a291..ef27858f 100644 --- a/tests/core/test_rag_application.py +++ b/tests/core/test_rag_application.py @@ -30,10 +30,7 @@ def rag_app(): # Test load knowledge file async def test_add_knowledge_file(rag_app: RagApplication): data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham") - print(len(rag_app.index.docstore.docs)) await rag_app.load_knowledge(data_dir) - print(len(rag_app.index.docstore.docs)) - assert len(rag_app.index.docstore.docs) > 0 # Test rag query From cd9dc826feeb108df22fd56957273d62285945bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Wed, 12 Jun 2024 13:36:21 +0800 Subject: [PATCH 11/17] Add queue --- src/pai_rag/modules/index/index.py | 7 +- .../modules/index/my_vector_store_index.py | 143 ++++++++++++++++++ .../retriever/my_vector_index_retriever.py | 2 +- 3 files changed, 147 insertions(+), 5 deletions(-) create mode 100644 src/pai_rag/modules/index/my_vector_store_index.py diff --git a/src/pai_rag/modules/index/index.py b/src/pai_rag/modules/index/index.py index 4221193b..16498956 100644 --- a/src/pai_rag/modules/index/index.py +++ b/src/pai_rag/modules/index/index.py @@ -3,8 +3,7 @@ import sys from typing import Dict, List, Any -from llama_index.core import VectorStoreIndex - +from pai_rag.modules.index.my_vector_store_index import MyVectorStoreIndex from llama_index.core import load_index_from_storage from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG @@ -59,7 +58,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): def create_indices(self): logging.info("Empty index, need to create indices.") - vector_index = VectorStoreIndex( + vector_index = MyVectorStoreIndex( nodes=[], storage_context=self.storage_context, embed_model=self.embed_model ) logging.info("Created vector_index.") @@ -70,7 +69,7 @@ def load_indices(self): if isinstance(self.storage_context.vector_store, FaissVectorStore): vector_index = load_index_from_storage(storage_context=self.storage_context) else: - vector_index = VectorStoreIndex( + vector_index = MyVectorStoreIndex( nodes=[], storage_context=self.storage_context, embed_model=self.embed_model, diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py new file mode 100644 index 00000000..1a447c44 --- /dev/null +++ b/src/pai_rag/modules/index/my_vector_store_index.py @@ -0,0 +1,143 @@ +"""Base vector store index. + +An index that is built on top of an existing vector store. 
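+
+This subclass overlaps embedding and indexing: node batches are embedded on
+the calling thread and handed through a bounded queue to a background worker
+thread, which writes each batch to the vector store and docstore.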
+
+"""
+
+import logging
+from typing import Any, Sequence
+from queue import Queue, Empty
+import threading
+from llama_index.core import VectorStoreIndex
+from llama_index.core.data_structs.data_structs import IndexDict
+from llama_index.core.schema import (
+    BaseNode,
+    ImageNode,
+    IndexNode,
+)
+from llama_index.core.utils import iter_batch
+
+logger = logging.getLogger(__name__)
+
+
+class MyVectorStoreIndex(VectorStoreIndex):
+    async def _async_add_nodes_to_index(
+        self,
+        index_struct: IndexDict,
+        nodes: Sequence[BaseNode],
+        show_progress: bool = False,
+        **insert_kwargs: Any,
+    ) -> None:
+        """Asynchronously add nodes to index."""
+        if not nodes:
+            return
+
+        for nodes_batch in iter_batch(nodes, self._insert_batch_size):
+            nodes_batch = await self._aget_node_with_embedding(
+                nodes_batch, show_progress
+            )
+            new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs)
+
+            # if the vector store doesn't store text, we need to add the nodes to the
+            # index struct and document store
+            if not self._vector_store.stores_text or self._store_nodes_override:
+                for node, new_id in zip(nodes_batch, new_ids):
+                    # NOTE: remove embedding from node to avoid duplication
+                    node_without_embedding = node.copy()
+                    node_without_embedding.embedding = None
+
+                    index_struct.add_node(node_without_embedding, text_id=new_id)
+                    self._docstore.add_documents(
+                        [node_without_embedding], allow_update=True
+                    )
+            else:
+                # NOTE: if the vector store keeps text,
+                # we only need to add image and index nodes
+                for node, new_id in zip(nodes_batch, new_ids):
+                    if isinstance(node, (ImageNode, IndexNode)):
+                        # NOTE: remove embedding from node to avoid duplication
+                        node_without_embedding = node.copy()
+                        node_without_embedding.embedding = None
+
+                        index_struct.add_node(node_without_embedding, text_id=new_id)
+                        self._docstore.add_documents(
+                            [node_without_embedding], allow_update=True
+                        )
+
+    def _add_nodes_batch_to_index(
+        self,
+        q: Queue,
+        index_struct,
+        insert_kwargs,
+    ):
+        i = 0
+        while True:
+            try:
+                nodes_batch = q.get(timeout=3)
+            except Empty:
+                continue
+
+            if nodes_batch is None:
+                q.task_done()
+                return
+
+            i += 1
+            print(f"Consuming batch {i}, batch size {len(nodes_batch)}")
+            new_ids = self._vector_store.add(nodes_batch, **insert_kwargs)
+
+            if not self._vector_store.stores_text or self._store_nodes_override:
+                print("saving to docstore!!!")
+                # NOTE: if the vector store doesn't store text,
+                # we need to add the nodes to the index struct and document store
+                for node, new_id in zip(nodes_batch, new_ids):
+                    # NOTE: remove embedding from node to avoid duplication
+                    node_without_embedding = node.copy()
+                    node_without_embedding.embedding = None
+
+                    index_struct.add_node(node_without_embedding, text_id=new_id)
+                    self._docstore.add_documents(
+                        [node_without_embedding], allow_update=True
+                    )
+            else:
+                print("Skipping saving to docstore!!!")
+                # NOTE: if the vector store keeps text,
+                # we only need to add image and index nodes
+                for node, new_id in zip(nodes_batch, new_ids):
+                    if isinstance(node, (ImageNode, IndexNode)):
+                        # NOTE: remove embedding from node to avoid duplication
+                        node_without_embedding = node.copy()
+                        node_without_embedding.embedding = None
+
+                        index_struct.add_node(node_without_embedding, text_id=new_id)
+                        self._docstore.add_documents(
+                            [node_without_embedding], allow_update=True
+                        )
+            q.task_done()
+
+    def _add_nodes_to_index(
+        self,
+        index_struct: IndexDict,
+        nodes: Sequence[BaseNode],
+        show_progress: bool = False,
+        **insert_kwargs: Any,
+    ) -> None:
+        """Add document to index."""
+        if not nodes:
+            return
+
+        q = Queue(maxsize=100)
+
+        work_thread = threading.Thread(
+            target=self._add_nodes_batch_to_index, args=(q, index_struct, insert_kwargs)
+        )
+        work_thread.start()
+
+        i = 0
+        for nodes_batch in iter_batch(nodes, 100):
+            nodes_batch = self._get_node_with_embedding(nodes_batch, show_progress)
+            q.put(nodes_batch)
+            i += 1
+            print(f"produced batch {i}, batch size {len(nodes_batch)}")
+
+        q.put(None)
+        q.join()
diff --git a/src/pai_rag/modules/retriever/my_vector_index_retriever.py b/src/pai_rag/modules/retriever/my_vector_index_retriever.py
index 16c98946..755f02a6 100644
--- a/src/pai_rag/modules/retriever/my_vector_index_retriever.py
+++ b/src/pai_rag/modules/retriever/my_vector_index_retriever.py
@@ -25,7 +25,7 @@ class MyVectorIndexRetriever(VectorIndexRetriever):
     and return the results with the query_result.similarities sorted in descending order.

     Args:
-        index (VectorStoreIndex): vector store index.
+        index (MyVectorStoreIndex): vector store index.
         similarity_top_k (int): number of top k results to return.
         vector_store_query_mode (str): vector store query mode
             See reference for VectorStoreQueryMode for full list of supported modes.

From 8e5c7f696b838c2dfbc912a20cc9397a56c2bf5f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Wed, 12 Jun 2024 14:36:05 +0800
Subject: [PATCH 12/17] Update batch size

---
 src/pai_rag/modules/index/my_vector_store_index.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py
index 1a447c44..54e68cb3 100644
--- a/src/pai_rag/modules/index/my_vector_store_index.py
+++ b/src/pai_rag/modules/index/my_vector_store_index.py
@@ -133,7 +133,7 @@ def _add_nodes_to_index(
         work_thread.start()

         i = 0
-        for nodes_batch in iter_batch(nodes, 100):
+        for nodes_batch in iter_batch(nodes, 500):
             nodes_batch = self._get_node_with_embedding(nodes_batch, show_progress)
             q.put(nodes_batch)
             i += 1

From 3a5155b8dcf9687c3310a7bed69896436d519155 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Wed, 12 Jun 2024 16:57:22 +0800
Subject: [PATCH 13/17] Add async interface

---
 src/pai_rag/data/rag_dataloader.py                 |   4 +-
 src/pai_rag/modules/embedding/embedding.py         |   4 +-
 .../embedding/my_huggingface_embedding.py          | 143 +++++++++++++++
 src/pai_rag/modules/index/index.py                 |   5 +
 .../modules/index/my_vector_store_index.py         | 167 +++++++-----------
 .../retriever/my_elasticsearch_store.py            |   1 +
 6 files changed, 216 insertions(+), 108 deletions(-)
 create mode 100644 src/pai_rag/modules/embedding/my_huggingface_embedding.py

diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py
index a4d3c5a1..b3bcec7a 100644
--- a/src/pai_rag/data/rag_dataloader.py
+++ b/src/pai_rag/data/rag_dataloader.py
@@ -35,8 +35,8 @@ def __init__(
     ):
         self.datareader_factory = datareader_factory
         self.node_parser = node_parser
-        self.index = index
         self.oss_cache = oss_cache
+        self.index = index

         if use_local_qa_model:
             # The API does not support this option yet
@@ -109,7 +109,7 @@

         logger.info("[DataReader] Start inserting to index.")

-        self.index.insert_nodes(nodes)
+        await self.index.insert_nodes_async(nodes)
         self.index.storage_context.persist(persist_dir=store_path.persist_path)
         logger.info(f"Inserted {len(nodes)} nodes successfully.")
         return
diff --git a/src/pai_rag/modules/embedding/embedding.py b/src/pai_rag/modules/embedding/embedding.py
index 8b8350aa..aa2ea57e 100644
--- a/src/pai_rag/modules/embedding/embedding.py
+++ b/src/pai_rag/modules/embedding/embedding.py
@@ -3,9 +3,9 @@
 from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
 from llama_index.embeddings.dashscope import DashScopeEmbedding
-from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 from pai_rag.modules.base.configurable_module import ConfigurableModule
 from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG
+from pai_rag.modules.embedding.my_huggingface_embedding import MyHuggingFaceEmbedding
 from pai_rag.utils.constants import DEFAULT_MODEL_DIR
 import os
 import logging
@@ -52,7 +52,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]):
             model_name = config.get("model_name", DEFAULT_HUGGINGFACE_EMBEDDING_MODEL)
             model_path = os.path.join(model_dir, model_name)

-            embed_model = HuggingFaceEmbedding(
+            embed_model = MyHuggingFaceEmbedding(
                 model_name=model_path, embed_batch_size=embed_batch_size
             )
diff --git a/src/pai_rag/modules/embedding/my_huggingface_embedding.py b/src/pai_rag/modules/embedding/my_huggingface_embedding.py
new file mode 100644
index 00000000..f21f8cbb
--- /dev/null
+++ b/src/pai_rag/modules/embedding/my_huggingface_embedding.py
@@ -0,0 +1,143 @@
+import logging
+from typing import Any, List, Optional
+
+from llama_index.core.base.embeddings.base import (
+    DEFAULT_EMBED_BATCH_SIZE,
+    BaseEmbedding,
+)
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
+from llama_index.core.callbacks import CallbackManager
+from llama_index.core.utils import get_cache_dir, infer_torch_device
+from llama_index.embeddings.huggingface.utils import (
+    DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
+    get_query_instruct_for_model_name,
+    get_text_instruct_for_model_name,
+)
+from sentence_transformers import SentenceTransformer
+
+DEFAULT_HUGGINGFACE_LENGTH = 512
+logger = logging.getLogger(__name__)
+
+
+# Add async interfaces
+class MyHuggingFaceEmbedding(BaseEmbedding):
+    max_length: int = Field(
+        default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0
+    )
+    normalize: bool = Field(default=True, description="Normalize embeddings or not.")
+    query_instruction: Optional[str] = Field(
+        description="Instruction to prepend to query text."
+    )
+    text_instruction: Optional[str] = Field(
+        description="Instruction to prepend to text."
+    )
+    cache_folder: Optional[str] = Field(
+        description="Cache folder for Hugging Face files."
+    )
+
+    _model: Any = PrivateAttr()
+    _device: str = PrivateAttr()
+
+    def __init__(
+        self,
+        model_name: str = DEFAULT_HUGGINGFACE_EMBEDDING_MODEL,
+        tokenizer_name: Optional[str] = "deprecated",
+        pooling: str = "deprecated",
+        max_length: Optional[int] = None,
+        query_instruction: Optional[str] = None,
+        text_instruction: Optional[str] = None,
+        normalize: bool = True,
+        model: Optional[Any] = "deprecated",
+        tokenizer: Optional[Any] = "deprecated",
+        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
+        cache_folder: Optional[str] = None,
+        trust_remote_code: bool = False,
+        device: Optional[str] = None,
+        callback_manager: Optional[CallbackManager] = None,
+        **model_kwargs,
+    ):
+        self._device = device or infer_torch_device()
+
+        cache_folder = cache_folder or get_cache_dir()
+
+        for variable, value in [
+            ("model", model),
+            ("tokenizer", tokenizer),
+            ("pooling", pooling),
+            ("tokenizer_name", tokenizer_name),
+        ]:
+            if value != "deprecated":
+                raise ValueError(
+                    f"{variable} is deprecated. Please remove it from the arguments."
+                )
+        if model_name is None:
+            raise ValueError("The `model_name` argument must be provided.")
+
+        self._model = SentenceTransformer(
+            model_name,
+            device=self._device,
+            cache_folder=cache_folder,
+            trust_remote_code=trust_remote_code,
+            prompts={
+                "query": query_instruction
+                or get_query_instruct_for_model_name(model_name),
+                "text": text_instruction
+                or get_text_instruct_for_model_name(model_name),
+            },
+            **model_kwargs,
+        )
+        if max_length:
+            self._model.max_seq_length = max_length
+        else:
+            max_length = self._model.max_seq_length
+
+        super().__init__(
+            embed_batch_size=embed_batch_size,
+            callback_manager=callback_manager,
+            model_name=model_name,
+            max_length=max_length,
+            normalize=normalize,
+            query_instruction=query_instruction,
+            text_instruction=text_instruction,
+        )
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "HuggingFaceEmbedding"
+
+    def _embed(
+        self,
+        sentences: List[str],
+        prompt_name: Optional[str] = None,
+    ) -> List[List[float]]:
+        """Embed sentences."""
+        return self._model.encode(
+            sentences,
+            batch_size=self.embed_batch_size,
+            prompt_name=prompt_name,
+            normalize_embeddings=self.normalize,
+        ).tolist()
+
+    def _get_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding."""
+        return self._embed(query, prompt_name="query")
+
+    async def _aget_query_embedding(self, query: str) -> List[float]:
+        """Get query embedding async."""
+        return self._get_query_embedding(query)
+
+    async def _aget_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding async."""
+        return self._get_text_embedding(text)
+
+    def _get_text_embedding(self, text: str) -> List[float]:
+        """Get text embedding."""
+        return self._embed(text, prompt_name="text")
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        return self._embed(texts, prompt_name="text")
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Get text embeddings."""
+        return self._embed(texts, prompt_name="text")
diff --git a/src/pai_rag/modules/index/index.py b/src/pai_rag/modules/index/index.py
index 16498956..3ce7fdbf 100644
--- a/src/pai_rag/modules/index/index.py
+++ b/src/pai_rag/modules/index/index.py
@@ -68,6 +68,11 @@ def create_indices(self):
     def load_indices(self):
         if isinstance(self.storage_context.vector_store, FaissVectorStore):
             vector_index = load_index_from_storage(storage_context=self.storage_context)
+            return MyVectorStoreIndex(
+                nodes=list(vector_index.docstore.docs.values()),
+                storage_context=self.storage_context,
+                embed_model=self.embed_model,
+            )
         else:
             vector_index = MyVectorStoreIndex(
                 nodes=[],
diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py
index 54e68cb3..60362288 100644
--- a/src/pai_rag/modules/index/my_vector_store_index.py
+++ b/src/pai_rag/modules/index/my_vector_store_index.py
@@ -4,15 +4,13 @@

 """

+import asyncio
 import logging
 from typing import Any, Sequence
-from queue import Queue, Empty
-import threading
 from llama_index.core import VectorStoreIndex
 from llama_index.core.data_structs.data_structs import IndexDict
 from llama_index.core.schema import (
     BaseNode,
-    ImageNode,
     IndexNode,
 )
 from llama_index.core.utils import iter_batch
@@ -20,7 +18,37 @@

 logger = logging.getLogger(__name__)


+def call_async(coro):
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        return asyncio.run(coro)
+    else:
+        return loop.run_until_complete(coro)
+
+
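+# Design note: _async_add_nodes_to_index embeds batches sequentially, then
+# flushes all vector-store writes concurrently via asyncio.gather; nodes are
+# also written to the docstore only when the vector store does not keep text
+# (or when node storage is explicitly overridden).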
 class MyVectorStoreIndex(VectorStoreIndex):
+    async def _postprocess_batch(
+        self,
+        index_struct: IndexDict,
+        nodes_batch: Sequence[BaseNode],
+        **insert_kwargs: Any,
+    ):
+        new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs)
+
+        # if the vector store doesn't store text, we need to add the nodes to the
+        # index struct and document store
+        if not self._vector_store.stores_text or self._store_nodes_override:
+            for node, new_id in zip(nodes_batch, new_ids):
+                # NOTE: remove embedding from node to avoid duplication
+                node_without_embedding = node.copy()
+                node_without_embedding.embedding = None
+
+                index_struct.add_node(node_without_embedding, text_id=new_id)
+                self._docstore.add_documents(
+                    [node_without_embedding], allow_update=True
+                )
+
     async def _async_add_nodes_to_index(
         self,
         index_struct: IndexDict,
@@ -32,112 +60,43 @@ async def _async_add_nodes_to_index(
         if not nodes:
             return

+        batch_process_coroutines = []
+
         for nodes_batch in iter_batch(nodes, self._insert_batch_size):
             nodes_batch = await self._aget_node_with_embedding(
                 nodes_batch, show_progress
             )
-            new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs)
-
-            # if the vector store doesn't store text, we need to add the nodes to the
-            # index struct and document store
-            if not self._vector_store.stores_text or self._store_nodes_override:
-                for node, new_id in zip(nodes_batch, new_ids):
-                    # NOTE: remove embedding from node to avoid duplication
-                    node_without_embedding = node.copy()
-                    node_without_embedding.embedding = None
-
-                    index_struct.add_node(node_without_embedding, text_id=new_id)
-                    self._docstore.add_documents(
-                        [node_without_embedding], allow_update=True
-                    )
-            else:
-                # NOTE: if the vector store keeps text,
-                # we only need to add image and index nodes
-                for node, new_id in zip(nodes_batch, new_ids):
-                    if isinstance(node, (ImageNode, IndexNode)):
-                        # NOTE: remove embedding from node to avoid duplication
-                        node_without_embedding = node.copy()
-                        node_without_embedding.embedding = None
-
-                        index_struct.add_node(node_without_embedding, text_id=new_id)
-                        self._docstore.add_documents(
-                            [node_without_embedding], allow_update=True
-                        )
-
-    def _add_nodes_batch_to_index(
-        self,
-        q: Queue,
-        index_struct,
-        insert_kwargs,
-    ):
-        i = 0
-        while True:
-            try:
-                nodes_batch = q.get(timeout=3)
-            except Empty:
-                continue
-
-            if nodes_batch is None:
-                q.task_done()
-                return
-
-            i += 1
-            print(f"Consuming batch {i}, batch size {len(nodes_batch)}")
-            new_ids = self._vector_store.add(nodes_batch, **insert_kwargs)
-
-            if not self._vector_store.stores_text or self._store_nodes_override:
-                print("saving to docstore!!!")
-                # NOTE: if the vector store doesn't store text,
-                # we need to add the nodes to the index struct and document store
-                for node, new_id in zip(nodes_batch, new_ids):
-                    # NOTE: remove embedding from node to avoid duplication
-                    node_without_embedding = node.copy()
-                    node_without_embedding.embedding = None
-
-                    index_struct.add_node(node_without_embedding, text_id=new_id)
-                    self._docstore.add_documents(
-                        [node_without_embedding], allow_update=True
-                    )
-            else:
-                print("Skipping saving to docstore!!!")
-                # NOTE: if the vector store keeps text,
-                # we only need to add image and index nodes
-                for node, new_id in zip(nodes_batch, new_ids):
-                    if isinstance(node, (ImageNode, IndexNode)):
-                        # NOTE: remove embedding from node to avoid duplication
-                        node_without_embedding = node.copy()
-                        node_without_embedding.embedding = None
-
-                        index_struct.add_node(node_without_embedding, text_id=new_id)
-                        self._docstore.add_documents(
-                            [node_without_embedding], allow_update=True
-                        )
-            q.task_done()
-
-    def _add_nodes_to_index(
-        self,
-        index_struct: IndexDict,
-        nodes: Sequence[BaseNode],
-        show_progress: bool = False,
-        **insert_kwargs: Any,
-    ) -> None:
-        """Add document to index."""
-        if not nodes:
-            return

-        q = Queue(maxsize=100)
+            batch_process_coroutines.append(
+                self._postprocess_batch(index_struct, nodes_batch, **insert_kwargs)
+            )
+        await asyncio.gather(*batch_process_coroutines)

-        work_thread = threading.Thread(
-            target=self._add_nodes_batch_to_index, args=(q, index_struct, insert_kwargs)
+    async def _insert_async(
+        self, nodes: Sequence[BaseNode], **insert_kwargs: Any
+    ) -> None:
+        """Insert a document."""
+        await self._async_add_nodes_to_index(
+            self._index_struct, nodes, show_progress=True, **insert_kwargs
         )
-        work_thread.start()
-
-        i = 0
-        for nodes_batch in iter_batch(nodes, 500):
-            nodes_batch = self._get_node_with_embedding(nodes_batch, show_progress)
-            q.put(nodes_batch)
-            i += 1
-            print(f"produced batch {i}, batch size {len(nodes_batch)}")

-        q.put(None)
-        q.join()
+    async def insert_nodes_async(
+        self, nodes: Sequence[BaseNode], **insert_kwargs: Any
+    ) -> None:
+        """Insert nodes.
+
+        NOTE: overrides BaseIndex.insert_nodes.
+            VectorStoreIndex only stores nodes in document store
+            if vector store does not store text
+        """
+        for node in nodes:
+            if isinstance(node, IndexNode):
+                try:
+                    node.dict()
+                except ValueError:
+                    self._object_map[node.index_id] = node.obj
+                    node.obj = None
+
+        with self._callback_manager.as_trace("insert_nodes"):
+            await self._insert_async(nodes, **insert_kwargs)
+            self._storage_context.index_store.add_index_struct(self._index_struct)
diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
index d076dcf4..08bb6a02 100644
--- a/src/pai_rag/modules/retriever/my_elasticsearch_store.py
+++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
@@ -247,6 +247,7 @@ def __init__(
             vector_field=vector_field,
             metadata_mappings=metadata_mappings,
         )
+        self._store._create_index_if_not_exists()

         super().__init__(
             index_name=index_name,

From b3b12d98e19f6c66d9e174158d288c5acc3a3874 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Wed, 12 Jun 2024 17:08:08 +0800
Subject: [PATCH 14/17] Fix index conflict

---
 src/pai_rag/modules/index/store.py                      | 1 +
 src/pai_rag/modules/retriever/my_elasticsearch_store.py | 6 +++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/pai_rag/modules/index/store.py b/src/pai_rag/modules/index/store.py
index 09702055..4d423f5b 100644
--- a/src/pai_rag/modules/index/store.py
+++ b/src/pai_rag/modules/index/store.py
@@ -132,6 +132,7 @@ def _get_or_create_es(self):
             es_url=es_config["es_url"],
             es_user=es_config["es_user"],
             es_password=es_config["es_password"],
+            embedding_dimension=self.embed_dims,
             retrieval_strategy=AsyncDenseVectorStrategy(
                 hybrid=True, rrf={"window_size": 50}
             ),
diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
index 08bb6a02..3aff8a51 100644
--- a/src/pai_rag/modules/retriever/my_elasticsearch_store.py
+++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py
@@ -212,6 +212,7 @@ def __init__(
         es_password: Optional[str] = None,
         text_field: str = "content",
         vector_field: str = "embedding",
+        embedding_dimension: int = 1536,
         batch_size: int = 200,
         distance_strategy: Optional[DISTANCE_STRATEGIES] = "COSINE",
         retrieval_strategy: Optional[AsyncRetrievalStrategy] = None,
@@ -246,8 +247,11 @@ def __init__(
             text_field=text_field,
             vector_field=vector_field,
             metadata_mappings=metadata_mappings,
+            num_dimensions=embedding_dimension,
+        )
+        asyncio.get_event_loop().run_until_complete(
+            self._store._create_index_if_not_exists()
         )
-        self._store._create_index_if_not_exists()

         super().__init__(
             index_name=index_name,

From fa74763b1867da0595e6053766111733ae94e804 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Wed, 12 Jun 2024 21:36:02 +0800
Subject: [PATCH 15/17] Add per-request FAISS index path parameter

---
 src/pai_rag/app/api/models.py                      |  10 +-
 src/pai_rag/app/api/query.py                       |   2 +-
 src/pai_rag/core/rag_application.py                |  11 +-
 .../modules/chat/chat_engine_factory.py            |  34 ++++--
 .../modules/chat/llm_chat_engine_factory.py        |  31 +++--
 src/pai_rag/modules/index/index.py                 |  51 ++++----
 src/pai_rag/modules/index/index_utils.py           | 109 ++++++++++++++++++
 .../modules/index/my_vector_store_index.py         |  61 ++++++----
 src/pai_rag/modules/retriever/retriever.py         |  19 +--
 src/pai_rag/utils/store_utils.py                   |   4 +-
 src/pai_rag/utils/tokenizer.py                     |  16 +++
 11 files changed, 246 insertions(+), 102 deletions(-)
 create mode 100644 src/pai_rag/modules/index/index_utils.py
 create mode 100644 src/pai_rag/utils/tokenizer.py

diff --git a/src/pai_rag/app/api/models.py b/src/pai_rag/app/api/models.py
index 8cfede1e..62faf36e 100644
--- a/src/pai_rag/app/api/models.py
+++ b/src/pai_rag/app/api/models.py
@@ -2,13 +2,16 @@
 from typing import List, Dict


+class VectorDbConfig(BaseModel):
+    faiss_path: str | None = None
+
+
 class RagQuery(BaseModel):
     question: str
     temperature: float | None = 0.1
-    vector_topk: int | None = 3
-    score_threshold: float | None = 0.5
     chat_history: List[Dict[str, str]] | None = None
     session_id: str | None = None
+    vector_db: VectorDbConfig | None = None


 class LlmQuery(BaseModel):
@@ -20,8 +23,7 @@ class LlmQuery(BaseModel):

 class RetrievalQuery(BaseModel):
     question: str
-    topk: int | None = 3
-    score_threshold: float | None = 0.5
+    vector_db: VectorDbConfig | None = None


 class RagResponse(BaseModel):
diff --git a/src/pai_rag/app/api/query.py b/src/pai_rag/app/api/query.py
index 60bd0d54..1bda1ccb 100644
--- a/src/pai_rag/app/api/query.py
+++ b/src/pai_rag/app/api/query.py
@@ -44,7 +44,7 @@ async def load_data(input: DataInput):
     await rag_service.add_knowledge(
         file_dir=input.file_path, enable_qa_extraction=input.enable_qa_extraction
     )
-    return {"msg": "Update RAG configuration successfully."}
+    return {"msg": "Upload data successfully."}


 @router.post("/evaluate/response")
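With VectorDbConfig in the request models, a client can point a single query at
a session-specific FAISS store. A minimal sketch of such a payload follows; the
endpoint path and host are assumptions, but the field shapes mirror RagQuery and
VectorDbConfig as defined above:

    # Hypothetical client call; the route name is illustrative only.
    import requests

    payload = {
        "question": "What is PAI-RAG?",
        "session_id": "demo-session",
        "vector_db": {"faiss_path": "/app/localdata/demo_storage"},
    }
    requests.post("http://127.0.0.1:8000/service/query", json=payload)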
diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 2fc2f622..63eb45bf 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -22,7 +22,6 @@ def uuid_generator() -> str:
 class RagApplication:
     def __init__(self):
         self.name = "RagApplication"
-        logging.basicConfig(level=logging.INFO)  # Set the log level to INFO
         self.logger = logging.getLogger(__name__)

     def initialize(self, config):
@@ -80,8 +79,14 @@
                 answer="Empty query. Please input your question.", session_id=session_id
             )

+        sessioned_config = self.config
+        if query.vector_db and query.vector_db.faiss_path:
+            sessioned_config = self.config.copy()
+            sessioned_config.index.update({"persist_path": query.vector_db.faiss_path})
+            self.logger.debug("Using session config: %s", sessioned_config)
+
         chat_engine_factory = module_registry.get_module_with_config(
-            "ChatEngineFactoryModule", self.config
+            "ChatEngineFactoryModule", sessioned_config
         )
         query_chat_engine = chat_engine_factory.get_chat_engine(
             session_id, query.chat_history
         )
@@ -89,7 +94,7 @@
         response = await query_chat_engine.achat(query.question)

         chat_store = module_registry.get_module_with_config(
-            "ChatStoreModule", self.config
+            "ChatStoreModule", sessioned_config
         )
         chat_store.persist()
         return RagResponse(answer=response.response, session_id=session_id)
diff --git a/src/pai_rag/modules/chat/chat_engine_factory.py b/src/pai_rag/modules/chat/chat_engine_factory.py
index ea8eea29..69282019 100644
--- a/src/pai_rag/modules/chat/chat_engine_factory.py
+++ b/src/pai_rag/modules/chat/chat_engine_factory.py
@@ -18,17 +18,11 @@
 logger = logging.getLogger(__name__)


-class ChatEngineFactoryModule(ConfigurableModule):
-    @staticmethod
-    def get_dependencies() -> List[str]:
-        return ["QueryEngineModule", "ChatStoreModule"]
-
-    def _create_new_instance(self, new_params: Dict[str, Any]):
-        self.config = new_params[MODULE_PARAM_CONFIG]
-        self.query_engine = new_params["QueryEngineModule"]
-        self.chat_store = new_params["ChatStoreModule"]
-
-        return self
+class ChatEngineFactory:
+    def __init__(self, chat_type, query_engine, chat_store):
+        self.chat_type = chat_type
+        self.query_engine = query_engine
+        self.chat_store = chat_store

     def get_chat_engine(self, session_id, chat_history):
         chat_memory = self.chat_store.get_chat_memory_buffer(session_id)
@@ -36,7 +30,8 @@ def get_chat_engine(self, session_id, chat_history):
         history_messages = parse_chat_messages(chat_history)
         for hist_mes in history_messages:
             chat_memory.put(hist_mes)
-        if self.config.type == "CondenseQuestionChatEngine":
+
+        if self.chat_type == "CondenseQuestionChatEngine":
             my_chat_engine = CondenseQuestionChatEngine.from_defaults(
                 query_engine=self.query_engine,
                 condense_question_prompt=CONDENSE_QUESTION_CHAT_ENGINE_PROMPT_ZH,
@@ -47,3 +42,18 @@
             return my_chat_engine
         else:
-            raise ValueError(f"Unknown chat_engine_type: {self.config.type}")
+            raise ValueError(f"Unknown chat_engine_type: {self.chat_type}")
+
+
+class ChatEngineFactoryModule(ConfigurableModule):
+    @staticmethod
+    def get_dependencies() -> List[str]:
+        return ["QueryEngineModule", "ChatStoreModule"]
+
+    def _create_new_instance(self, new_params: Dict[str, Any]):
+        config = new_params[MODULE_PARAM_CONFIG]
+        query_engine = new_params["QueryEngineModule"]
+        chat_store = new_params["ChatStoreModule"]
+
+        return ChatEngineFactory(
+            config.type, query_engine=query_engine, chat_store=chat_store
+        )
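The refactor above keeps per-request state off the registry singleton: the
module now returns a plain ChatEngineFactory, so a factory built from one
sessioned_config cannot leak into other requests. A rough usage sketch,
assuming query_engine and chat_store have already been produced by the module
registry (hypothetical wiring, for illustration only):

    factory = ChatEngineFactory(
        "CondenseQuestionChatEngine", query_engine=query_engine, chat_store=chat_store
    )
    engine = factory.get_chat_engine(session_id="demo-session", chat_history=None)
    response = await engine.achat("What does this document cover?")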
new_params["LlmModule"] - self.chat_store = new_params["ChatStoreModule"] - return self +class LlmChatEngineFactory: + def __init__(self, chat_type, llm, chat_store): + self.chat_type = chat_type + self.llm = llm + self.chat_store = chat_store def get_chat_engine(self, session_id, chat_history): chat_memory = self.chat_store.get_chat_memory_buffer(session_id) @@ -30,14 +25,26 @@ def get_chat_engine(self, session_id, chat_history): for hist_mes in history_messages: chat_memory.put(hist_mes) - if self.config.type == "SimpleChatEngine": + if self.chat_type == "SimpleChatEngine": my_chat_engine = SimpleChatEngine.from_defaults( llm=self.llm, memory=chat_memory, verbose=True, ) logger.info("simple chat_engine instance created") + + return my_chat_engine else: raise ValueError(f"Unknown chat_engine_type: {self.config.type}") - return my_chat_engine + +class LlmChatEngineFactoryModule(ConfigurableModule): + @staticmethod + def get_dependencies() -> List[str]: + return ["LlmModule", "ChatStoreModule"] + + def _create_new_instance(self, new_params: Dict[str, Any]): + config = new_params[MODULE_PARAM_CONFIG] + llm = new_params["LlmModule"] + chat_store = new_params["ChatStoreModule"] + return LlmChatEngineFactory(config.type, llm, chat_store) diff --git a/src/pai_rag/modules/index/index.py b/src/pai_rag/modules/index/index.py index 3ce7fdbf..84e22697 100644 --- a/src/pai_rag/modules/index/index.py +++ b/src/pai_rag/modules/index/index.py @@ -4,7 +4,7 @@ from typing import Dict, List, Any from pai_rag.modules.index.my_vector_store_index import MyVectorStoreIndex -from llama_index.core import load_index_from_storage +from pai_rag.modules.index.index_utils import load_index_from_storage from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG from pai_rag.modules.index.store import RagStore @@ -32,51 +32,44 @@ class IndexModule(ConfigurableModule): def get_dependencies() -> List[str]: return ["EmbeddingModule"] - def _get_embed_vec_dim(self): + def _get_embed_vec_dim(self, embed_model): # Get dimension size of embedding vector - return len(self.embed_model._get_text_embedding("test")) + return len(embed_model._get_text_embedding("test")) def _create_new_instance(self, new_params: Dict[str, Any]): - self.config = new_params[MODULE_PARAM_CONFIG] - self.embed_model = new_params["EmbeddingModule"] - self.embed_dims = self._get_embed_vec_dim() - persist_path = self.config.get("persist_path", DEFAULT_PERSIST_DIR) - folder_name = get_store_persist_directory_name(self.config, self.embed_dims) + config = new_params[MODULE_PARAM_CONFIG] + embed_model = new_params["EmbeddingModule"] + embed_dims = self._get_embed_vec_dim(embed_model) + persist_path = config.get("persist_path", DEFAULT_PERSIST_DIR) + folder_name = get_store_persist_directory_name(config, embed_dims) store_path.persist_path = os.path.join(persist_path, folder_name) + is_empty = not os.path.exists(store_path.persist_path) + rag_store = RagStore(config, store_path.persist_path, is_empty, embed_dims) + storage_context = rag_store.get_storage_context() - self.is_empty = not os.path.exists(store_path.persist_path) - rag_store = RagStore( - self.config, store_path.persist_path, self.is_empty, self.embed_dims - ) - self.storage_context = rag_store.get_storage_context() - - if self.is_empty: - return self.create_indices() + if is_empty: + return self.create_indices(storage_context, embed_model) else: - return self.load_indices() + return self.load_indices(storage_context, 
embed_model) - def create_indices(self): + def create_indices(self, storage_context, embed_model): logging.info("Empty index, need to create indices.") vector_index = MyVectorStoreIndex( - nodes=[], storage_context=self.storage_context, embed_model=self.embed_model + nodes=[], storage_context=storage_context, embed_model=embed_model ) logging.info("Created vector_index.") return vector_index - def load_indices(self): - if isinstance(self.storage_context.vector_store, FaissVectorStore): - vector_index = load_index_from_storage(storage_context=self.storage_context) - return MyVectorStoreIndex( - nodes=list(vector_index.docstore.docs.values()), - storage_context=self.storage_context, - embed_model=self.embed_model, - ) + def load_indices(self, storage_context, embed_model): + if isinstance(storage_context.vector_store, FaissVectorStore): + vector_index = load_index_from_storage(storage_context=storage_context) + return vector_index else: vector_index = MyVectorStoreIndex( nodes=[], - storage_context=self.storage_context, - embed_model=self.embed_model, + storage_context=storage_context, + embed_model=embed_model, ) return vector_index diff --git a/src/pai_rag/modules/index/index_utils.py b/src/pai_rag/modules/index/index_utils.py new file mode 100644 index 00000000..8f1ff4e5 --- /dev/null +++ b/src/pai_rag/modules/index/index_utils.py @@ -0,0 +1,109 @@ +import logging +from typing import Any, List, Optional, Sequence + +from llama_index.core.indices.base import BaseIndex +from llama_index.core.indices.composability.graph import ComposableGraph +from llama_index.core.indices.registry import INDEX_STRUCT_TYPE_TO_INDEX_CLASS +from llama_index.core.storage.storage_context import StorageContext +from llama_index.core.data_structs.struct_type import IndexStructType + +from pai_rag.modules.index.my_vector_store_index import MyVectorStoreIndex + + +MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS = INDEX_STRUCT_TYPE_TO_INDEX_CLASS +MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS[ + IndexStructType.VECTOR_STORE +] = MyVectorStoreIndex + +logger = logging.getLogger(__name__) + + +def load_index_from_storage( + storage_context: StorageContext, + index_id: Optional[str] = None, + **kwargs: Any, +) -> BaseIndex: + """Load index from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + index_id (Optional[str]): ID of the index to load. + Defaults to None, which assumes there's only a single index + in the index store and load it. + **kwargs: Additional keyword args to pass to the index constructors. + """ + index_ids: Optional[Sequence[str]] + if index_id is None: + index_ids = None + else: + index_ids = [index_id] + + indices = load_indices_from_storage(storage_context, index_ids=index_ids, **kwargs) + + if len(indices) == 0: + raise ValueError( + "No index in storage context, check if you specified the right persist_dir." + ) + elif len(indices) > 1: + raise ValueError( + f"Expected to load a single index, but got {len(indices)} instead. " + "Please specify index_id." + ) + + return indices[0] + + +def load_indices_from_storage( + storage_context: StorageContext, + index_ids: Optional[Sequence[str]] = None, + **kwargs: Any, +) -> List[BaseIndex]: + """Load multiple indices from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + index_id (Optional[Sequence[str]]): IDs of the indices to load. + Defaults to None, which loads all indices in the index store. 
+ **kwargs: Additional keyword args to pass to the index constructors. + """ + if index_ids is None: + logger.info("Loading all indices.") + index_structs = storage_context.index_store.index_structs() + else: + logger.info(f"Loading indices with ids: {index_ids}") + index_structs = [] + for index_id in index_ids: + index_struct = storage_context.index_store.get_index_struct(index_id) + if index_struct is None: + raise ValueError(f"Failed to load index with ID {index_id}") + index_structs.append(index_struct) + + indices = [] + for index_struct in index_structs: + type_ = index_struct.get_type() + index_cls = MODIFIED_INDEX_STRUCT_TYPE_TO_INDEX_CLASS[type_] + index = index_cls( + index_struct=index_struct, storage_context=storage_context, **kwargs + ) + indices.append(index) + return indices + + +def load_graph_from_storage( + storage_context: StorageContext, + root_id: str, + **kwargs: Any, +) -> ComposableGraph: + """Load composable graph from storage context. + + Args: + storage_context (StorageContext): storage context containing + docstore, index store and vector store. + root_id (str): ID of the root index of the graph. + **kwargs: Additional keyword args to pass to the index constructors. + """ + indices = load_indices_from_storage(storage_context, index_ids=None, **kwargs) + all_indices = {index.index_id: index for index in indices} + return ComposableGraph(all_indices=all_indices, root_id=root_id) diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py index 60362288..37350852 100644 --- a/src/pai_rag/modules/index/my_vector_store_index.py +++ b/src/pai_rag/modules/index/my_vector_store_index.py @@ -28,26 +28,44 @@ def call_async(coro): class MyVectorStoreIndex(VectorStoreIndex): - async def _postprocess_batch( + async def _process_one_batch( self, + nodes_batch: Sequence[Sequence[BaseNode]], index_struct: IndexDict, - nodes_batch: Sequence[BaseNode], + semaphore: asyncio.Semaphore, **insert_kwargs: Any, ): - new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs) - - # if the vector store doesn't store text, we need to add the nodes to the - # index struct and document store - if not self._vector_store.stores_text or self._store_nodes_override: - for node, new_id in zip(nodes_batch, new_ids): - # NOTE: remove embedding from node to avoid duplication - node_without_embedding = node.copy() - node_without_embedding.embedding = None - - index_struct.add_node(node_without_embedding, text_id=new_id) - self._docstore.add_documents( - [node_without_embedding], allow_update=True + async with semaphore: + new_ids = await self._vector_store.async_add(nodes_batch, **insert_kwargs) + + # if the vector store doesn't store text, we need to add the nodes to the + # index struct and document store + if not self._vector_store.stores_text or self._store_nodes_override: + for node, new_id in zip(nodes_batch, new_ids): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + + async def _postprocess_all_batch( + self, + nodes_batch_list: Sequence[Sequence[BaseNode]], + index_struct: IndexDict, + **insert_kwargs: Any, + ): + asyncio_semaphore = asyncio.Semaphore(10) + batch_process_coroutines = [] + for nodes_batch in nodes_batch_list: + batch_process_coroutines.append( + self._process_one_batch( 
+ nodes_batch, index_struct, asyncio_semaphore, **insert_kwargs ) + ) + await asyncio.gather(*batch_process_coroutines) async def _async_add_nodes_to_index( self, @@ -60,17 +78,16 @@ async def _async_add_nodes_to_index( if not nodes: return - batch_process_coroutines = [] - - for nodes_batch in iter_batch(nodes, self._insert_batch_size): + node_batch_list = [] + for nodes_batch in iter_batch(nodes, 500): nodes_batch = await self._aget_node_with_embedding( nodes_batch, show_progress ) + node_batch_list.append(nodes_batch) - batch_process_coroutines.append( - self._postprocess_batch(index_struct, nodes_batch, **insert_kwargs) - ) - await asyncio.gather(*batch_process_coroutines) + await self._postprocess_all_batch( + node_batch_list, index_struct, **insert_kwargs + ) async def _insert_async( self, nodes: Sequence[BaseNode], **insert_kwargs: Any diff --git a/src/pai_rag/modules/retriever/retriever.py b/src/pai_rag/modules/retriever/retriever.py index 77f3c153..d65b4a33 100644 --- a/src/pai_rag/modules/retriever/retriever.py +++ b/src/pai_rag/modules/retriever/retriever.py @@ -3,8 +3,6 @@ import logging from typing import Dict, List, Any -import jieba -from nltk.corpus import stopwords from llama_index.core.indices.list.base import SummaryIndex from llama_index.core.retrievers import QueryFusionRetriever from llama_index.core.tools import RetrieverTool @@ -12,7 +10,7 @@ from llama_index.core.retrievers import RouterRetriever from llama_index.core.vector_stores.types import VectorStoreQueryMode -# from llama_index.retrievers.bm25 import BM25Retriever +from pai_rag.utils.tokenizer import jieba_tokenizer from pai_rag.integrations.retrievers.bm25 import BM25Retriever from pai_rag.modules.base.configurable_module import ConfigurableModule from pai_rag.modules.base.module_constants import MODULE_PARAM_CONFIG @@ -22,19 +20,6 @@ logger = logging.getLogger(__name__) -stopword_list = stopwords.words("chinese") + stopwords.words("english") - - -## PUT in utils file and add stopword in TRIE structure. 
-def jieba_tokenize(text: str) -> List[str]: - tokens = [] - for w in jieba.lcut(text): - token = w.lower() - if token not in stopword_list: - tokens.append(token) - - return tokens - class RetrieverModule(ConfigurableModule): @staticmethod @@ -71,7 +56,7 @@ def _create_new_instance(self, new_params: Dict[str, Any]): bm25_retriever = BM25Retriever.from_defaults( index=vector_index, similarity_top_k=similarity_top_k, - tokenizer=jieba_tokenize, + tokenizer=jieba_tokenizer, ) if retrieval_mode == "embedding": diff --git a/src/pai_rag/utils/store_utils.py b/src/pai_rag/utils/store_utils.py index a42bb42c..0873375c 100644 --- a/src/pai_rag/utils/store_utils.py +++ b/src/pai_rag/utils/store_utils.py @@ -14,9 +14,9 @@ def get_store_persist_directory_name(storage_config, ndims): raw_text = "sample_store_key" vector_store_type = storage_config["vector_store"]["type"].lower() if vector_store_type == "chroma": - raw_text = json.dumps(storage_config["vector_store"]) + raw_text = json.dumps(storage_config["vector_store"], sort_keys=True) elif vector_store_type == "faiss": - raw_text = json.dumps(storage_config["vector_store"]) + raw_text = {"type": "faiss"} elif vector_store_type == "hologres": keywords = ["host", "port", "database", "table_name"] json_data = {k: storage_config["vector_store"][k] for k in keywords} diff --git a/src/pai_rag/utils/tokenizer.py b/src/pai_rag/utils/tokenizer.py new file mode 100644 index 00000000..168702b9 --- /dev/null +++ b/src/pai_rag/utils/tokenizer.py @@ -0,0 +1,16 @@ +import jieba +from nltk.corpus import stopwords +from typing import List + +stopword_list = stopwords.words("chinese") + stopwords.words("english") + + +## PUT in utils file and add stopword in TRIE structure. +def jieba_tokenizer(text: str) -> List[str]: + tokens = [] + for w in jieba.lcut(text): + token = w.lower() + if token not in stopword_list: + tokens.append(token) + + return tokens From 305ceeb9d620a4c97e6ce0d058e3e6382685ebbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Thu, 13 Jun 2024 19:51:30 +0800 Subject: [PATCH 16/17] Fix batch size --- src/pai_rag/modules/index/my_vector_store_index.py | 2 +- src/pai_rag/modules/retriever/my_elasticsearch_store.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pai_rag/modules/index/my_vector_store_index.py b/src/pai_rag/modules/index/my_vector_store_index.py index 37350852..b4fd3f5a 100644 --- a/src/pai_rag/modules/index/my_vector_store_index.py +++ b/src/pai_rag/modules/index/my_vector_store_index.py @@ -79,7 +79,7 @@ async def _async_add_nodes_to_index( return node_batch_list = [] - for nodes_batch in iter_batch(nodes, 500): + for nodes_batch in iter_batch(nodes, 100): nodes_batch = await self._aget_node_with_embedding( nodes_batch, show_progress ) diff --git a/src/pai_rag/modules/retriever/my_elasticsearch_store.py b/src/pai_rag/modules/retriever/my_elasticsearch_store.py index 3aff8a51..63205ccc 100644 --- a/src/pai_rag/modules/retriever/my_elasticsearch_store.py +++ b/src/pai_rag/modules/retriever/my_elasticsearch_store.py @@ -331,6 +331,8 @@ async def async_add( if len(nodes) == 0: return [] + add_kwargs.update({"max_retries": 3}) + embeddings: List[List[float]] = [] texts: List[str] = [] metadatas: List[dict] = [] From 68263a15bfc1e1c0a24acd6c2cf660ec2df821de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Thu, 13 Jun 2024 20:35:18 +0800 Subject: [PATCH 17/17] Update --- src/pai_rag/core/rag_application.py | 2 +- src/pai_rag/core/rag_service.py | 6 +++++- 
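Patch 16 trades throughput for stability: embedding batches of 100 bound the
number of vectors held in memory per await, and max_retries=3 is forwarded
through add_kwargs so bulk indexing can retry transient failures. The
bounded-concurrency pattern used by _postprocess_all_batch can be sketched in
isolation like this (vector_store is a stand-in for any store exposing an
async_add coroutine; the names are illustrative, not part of the patch):

    import asyncio

    async def write_batch(vector_store, batch, semaphore):
        # The semaphore caps in-flight writes, mirroring _process_one_batch.
        async with semaphore:
            await vector_store.async_add(batch)

    async def write_all(vector_store, batches):
        semaphore = asyncio.Semaphore(10)
        await asyncio.gather(
            *(write_batch(vector_store, b, semaphore) for b in batches)
        )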
From 68263a15bfc1e1c0a24acd6c2cf660ec2df821de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?=
Date: Thu, 13 Jun 2024 20:35:18 +0800
Subject: [PATCH 17/17] Use async data loading and log upload failures

---
 src/pai_rag/core/rag_application.py | 2 +-
 src/pai_rag/core/rag_service.py     | 6 +++++-
 tests/core/test_rag_application.py  | 4 ++--
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/pai_rag/core/rag_application.py b/src/pai_rag/core/rag_application.py
index 63eb45bf..cc314882 100644
--- a/src/pai_rag/core/rag_application.py
+++ b/src/pai_rag/core/rag_application.py
@@ -38,7 +38,7 @@ async def load_knowledge(self, file_dir, enable_qa_extraction=False):
         data_loader = module_registry.get_module_with_config(
             "DataLoaderModule", self.config
         )
-        await data_loader.load(file_dir, enable_qa_extraction)
+        await data_loader.aload(file_dir, enable_qa_extraction)
diff --git a/src/pai_rag/core/rag_service.py b/src/pai_rag/core/rag_service.py
index 769f2ed6..4e768752 100644
--- a/src/pai_rag/core/rag_service.py
+++ b/src/pai_rag/core/rag_service.py
@@ -12,6 +12,9 @@
 from pai_rag.app.web.view_model import view_model
 from openinference.instrumentation import using_attributes
 from typing import Any, Dict
+import logging
+
+logger = logging.getLogger(__name__)


 def trace_correlation_id(function):
@@ -55,7 +58,8 @@ async def add_knowledge_async(
         try:
             await self.rag.load_knowledge(file_dir, enable_qa_extraction)
             self.tasks_status[task_id] = "completed"
-        except Exception:
+        except Exception as ex:
+            logger.error(f"Upload failed: {ex}")
             self.tasks_status[task_id] = "failed"

     def get_task_status(self, task_id: str) -> str:
diff --git a/tests/core/test_rag_application.py b/tests/core/test_rag_application.py
index 83006385..ef27858f 100644
--- a/tests/core/test_rag_application.py
+++ b/tests/core/test_rag_application.py
@@ -28,9 +28,9 @@ def rag_app():


 # Test load knowledge file
-def test_add_knowledge_file(rag_app: RagApplication):
+async def test_add_knowledge_file(rag_app: RagApplication):
     data_dir = os.path.join(BASE_DIR, "tests/testdata/paul_graham")
-    rag_app.load_knowledge(data_dir)
+    await rag_app.load_knowledge(data_dir)


 # Test rag query