From cf727465ce35e13cfc1961d369bdbc4b5ceddfc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Thu, 6 Jun 2024 20:53:23 +0800 Subject: [PATCH 1/2] Fix node hash for special file types --- src/pai_rag/data/rag_dataloader.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py index b6df3d64..6fe0c258 100644 --- a/src/pai_rag/data/rag_dataloader.py +++ b/src/pai_rag/data/rag_dataloader.py @@ -1,3 +1,5 @@ +import os +from typing import Any, Dict from llama_index.core import Settings from llama_index.core.schema import TextNode from llama_index.llms.huggingface import HuggingFaceLLM @@ -14,6 +16,8 @@ DEFAULT_LOCAL_QA_MODEL_PATH = "/huggingface/transformers/qwen_1.8b" +DOC_TYPES_DO_NOT_NEED_CHUNKING = set([".csv", ".xlsx", ".md", ".xls", ".htm", ".html"]) + class RagDataLoader: """ @@ -49,13 +53,26 @@ def __init__( logger.info("RagDataLoader initialized.") + def _extract_file_type(self, metadata: Dict[str, Any]): + file_name = metadata.get("file_name", "dummy.txt") + return os.path.splitext(file_name)[1] + async def load(self, file_directory: str, enable_qa_extraction: bool): data_reader = self.datareader_factory.get_reader(file_directory) docs = data_reader.load_data() nodes = [] + + doc_cnt_map = {} for doc in docs: - if doc.metadata.get("file_type", "Unknown") == "HTML": - node_id = node_id_hash(0, doc) + doc_type = self._extract_file_type(doc.metadata) + + if doc_type in DOC_TYPES_DO_NOT_NEED_CHUNKING: + doc_key = f"""{doc.metadata.get("file_path", "dummy")}""" + print(doc_key) + if doc_key not in doc_cnt_map: + doc_cnt_map[doc_key] = 0 + doc_cnt_map[doc_key] += 1 + node_id = node_id_hash(doc_cnt_map[doc_key], doc) nodes.append( TextNode(id_=node_id, text=doc.text, metadata=doc.metadata) ) From 60606206c4c735812b9299dcb1de764413a775da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=B9=E8=B7=83?= Date: Thu, 6 Jun 2024 20:54:47 +0800 Subject: [PATCH 2/2] Remove print --- src/pai_rag/data/rag_dataloader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pai_rag/data/rag_dataloader.py b/src/pai_rag/data/rag_dataloader.py index 6fe0c258..3dd0cd80 100644 --- a/src/pai_rag/data/rag_dataloader.py +++ b/src/pai_rag/data/rag_dataloader.py @@ -68,7 +68,6 @@ async def load(self, file_directory: str, enable_qa_extraction: bool): if doc_type in DOC_TYPES_DO_NOT_NEED_CHUNKING: doc_key = f"""{doc.metadata.get("file_path", "dummy")}""" - print(doc_key) if doc_key not in doc_cnt_map: doc_cnt_map[doc_key] = 0 doc_cnt_map[doc_key] += 1