diff --git a/poetry.lock b/poetry.lock
index f168d99d..9bfd357d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1174,13 +1174,13 @@ typing-inspect = ">=0.4.0,<1"

 [[package]]
 name = "datasets"
-version = "2.19.1"
+version = "2.19.2"
 description = "HuggingFace community-driven open-source library of datasets"
 optional = false
 python-versions = ">=3.8.0"
 files = [
-    {file = "datasets-2.19.1-py3-none-any.whl", hash = "sha256:f7a78d15896f45004ccac1c298f3c7121f92f91f6f2bfbd4e4f210f827e6e411"},
-    {file = "datasets-2.19.1.tar.gz", hash = "sha256:0df9ef6c5e9138cdb996a07385220109ff203c204245578b69cca905eb151d3a"},
+    {file = "datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e"},
+    {file = "datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef"},
 ]

 [package.dependencies]
@@ -1196,7 +1196,7 @@ pandas = "*"
 pyarrow = ">=12.0.0"
 pyarrow-hotfix = "*"
 pyyaml = ">=5.1"
-requests = ">=2.19.0"
+requests = ">=2.32.1"
 tqdm = ">=4.62.1"
 xxhash = "*"

@@ -1204,7 +1204,7 @@ xxhash = "*"
 apache-beam = ["apache-beam (>=2.26.0)"]
 audio = ["librosa", "soundfile (>=0.12.1)"]
 benchmarks = ["tensorflow (==2.12.0)", "torch (==2.0.1)", "transformers (==4.30.1)"]
-dev = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+dev = ["Pillow (>=9.4.0)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "ruff (>=0.3.0)", "s3fs", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 docs = ["s3fs", "tensorflow (>=2.6.0)", "torch", "transformers"]
 jax = ["jax (>=0.3.14)", "jaxlib (>=0.3.14)"]
 metrics-tests = ["Werkzeug (>=1.0.1)", "accelerate", "bert-score (>=0.3.6)", "jiwer", "langdetect", "mauve-text", "nltk", "requests-file (>=1.5.1)", "rouge-score", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "spacy (>=3.0.0)", "texttable (>=1.6.3)", "tldextract", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "typer (<0.5.0)"]
@@ -1212,9 +1212,9 @@ quality = ["ruff (>=0.3.0)"]
 s3 = ["s3fs"]
 tensorflow = ["tensorflow (>=2.6.0)"]
 tensorflow-gpu = ["tensorflow (>=2.6.0)"]
-tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
+tests = ["Pillow (>=9.4.0)", "absl-py", "apache-beam (>=2.26.0)", "elasticsearch (<8.0.0)", "faiss-cpu (>=1.6.4)", "jax (>=0.3.14)", "jaxlib (>=0.3.14)", "joblib (<1.3.0)", "joblibspark", "librosa", "lz4", "polars[timezone] (>=0.20.0)", "protobuf (<4.0.0)", "py7zr", "pyspark (>=3.4)", "pytest", "pytest-datadir", "pytest-xdist", "rarfile (>=4.0)", "s3fs (>=2021.11.1)", "soundfile (>=0.12.1)", "sqlalchemy", "tensorflow (>=2.6.0)", "tiktoken", "torch (>=2.0.0)", "transformers", "typing-extensions (>=4.6.1)", "zstandard"]
 torch = ["torch"]
-vision = ["Pillow (>=6.2.1)"]
+vision = ["Pillow (>=9.4.0)"]

 [[package]]
 name = "deprecated"
@@ -3039,17 +3039,18 @@ llama-index-core = ">=0.10.1,<0.11.0"

 [[package]]
 name = "llama-index-vector-stores-elasticsearch"
-version = "0.1.7"
+version = "0.2.0"
 description = "llama-index vector_stores elasticsearch integration"
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "llama_index_vector_stores_elasticsearch-0.1.7-py3-none-any.whl", hash = "sha256:99eda589e57cc1877b55686f03d810f53e91a6891872214972d1af4fac5dc440"},
-    {file = "llama_index_vector_stores_elasticsearch-0.1.7.tar.gz", hash = "sha256:7e5435be500ee0d8852efefb4b7891693310a459350bb2aeaea6f0d5d3a23975"},
+    {file = "llama_index_vector_stores_elasticsearch-0.2.0-py3-none-any.whl", hash = "sha256:098c98db48dfa513c7f4c33431251aeb0e432cd1617bb4186d23823c87f2a72c"},
+    {file = "llama_index_vector_stores_elasticsearch-0.2.0.tar.gz", hash = "sha256:9897f7d195f08ee8752bce8e774519ab02a1897fe30dd5d40e53556f9b798186"},
 ]

 [package.dependencies]
-elasticsearch = ">=8.12.0,<9.0.0"
+aiohttp = ">=3.9.5,<4.0.0"
+elasticsearch = ">=8.13.1,<9.0.0"
 llama-index-core = ">=0.10.1,<0.11.0"

 [[package]]
@@ -7838,4 +7839,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10.0,<3.12"
-content-hash = "909a2eaea31206898ab29c207414a5ec6b03c07fd2b30c56c419f577c768a6f8"
+content-hash = "5c891b42861023d77d87bf4027b2b3e53fe7eaad67d81c893f7f5be7e45a8e69"
diff --git a/pyproject.toml b/pyproject.toml
index 5bd81bf8..5398cebc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ llama-index-readers-database = "^0.1.3"
 llama-index-vector-stores-chroma = "^0.1.6"
 llama-index-vector-stores-faiss = "^0.1.2"
 llama-index-vector-stores-analyticdb = "^0.1.1"
-llama-index-vector-stores-elasticsearch = "^0.1.7"
+llama-index-vector-stores-elasticsearch = "^0.2.0"
 llama-index-vector-stores-milvus = "^0.1.10"
 gradio = "3.41.0"
 faiss-cpu = "^1.8.0"
diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py
index 66c0481b..5d2974ae 100644
--- a/src/pai_rag/integrations/readers/pai_pdf_reader.py
+++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py
@@ -14,7 +14,6 @@
 from llama_index.core import Settings
 from pai_rag.utils.constants import DEFAULT_EASYOCR_MODEL_DIR
 import json
-import sys
 import unicodedata
 import logging
 import tempfile
@@ -33,9 +32,10 @@ class PageItem(TypedDict):

 class PaiPDFReader(BaseReader):
     """Read PDF files including texts, tables, images.
+
     Args:
-    enable_image_ocr (bool): whether load ocr model to process images
-    model_dir: (str): ocr model path
+        enable_image_ocr (bool): whether to load the OCR model to process images
+        model_dir (str): OCR model path
     """

     def __init__(
@@ -54,21 +54,28 @@ def __init__(
             )
             logger.info("finished loading ocr model")

-    """剪切图片
-    """
-
-    def process_image(self, element: LTFigure, page_object: PageObject) -> str:
-        # 获取从PDF中裁剪图像的坐标
+    def process_pdf_image(self, element: LTFigure, page_object: PageObject) -> str:
+        """
+        Processes an image element from a PDF, crops it out, and performs OCR on the result.
+
+        Args:
+            element (LTFigure): An LTFigure object representing the image in the PDF, containing its coordinates.
+            page_object (PageObject): A PageObject representing the page in the PDF to be cropped.
+
+        Returns:
+            str: The OCR-processed text from the cropped image.
+        """
+        # Retrieve the image's coordinates
         [image_left, image_top, image_right, image_bottom] = [
             element.x0,
             element.y0,
             element.x1,
             element.y1,
         ]
-        # 使用坐标(left, bottom, right, top)裁剪页面
+        # Adjust the page's media box to crop the image based on the coordinates
         page_object.mediabox.lower_left = (image_left, image_bottom)
         page_object.mediabox.upper_right = (image_right, image_top)
-        # 将裁剪后的页面保存为新的PDF
+        # Save the cropped page as a new PDF file and perform OCR
         cropped_pdf_writer = PyPDF2.PdfWriter()
         with tempfile.NamedTemporaryFile(
             delete=True, suffix=".pdf"
@@ -76,12 +83,19 @@ def process_image(self, element: LTFigure, page_object: PageObject) -> str:
         ) as cropped_pdf_file:
             cropped_pdf_writer.add_page(page_object)
             cropped_pdf_writer.write(cropped_pdf_file)
             cropped_pdf_file.flush()
-            return self.convert_to_images(cropped_pdf_file.name)
+            # Return the OCR-processed text
+            return self.ocr_pdf(cropped_pdf_file.name)

-    """创建一个将PDF内容转换为image的函数
-    """
-
-    def convert_to_images(self, input_file: str) -> str:
+    def ocr_pdf(self, input_file: str) -> str:
+        """
+        Function to convert PDF content into an image and then perform OCR (Optical Character Recognition).
+
+        Args:
+            input_file (str): input file path.
+
+        Returns:
+            str: text extracted by OCR.
+        """
         images = convert_from_path(input_file)
         image = images[0]
         with tempfile.NamedTemporaryFile(
@@ -91,32 +105,34 @@ def convert_to_images(self, input_file: str) -> str:
             output_image_file.flush()
             return self.image_to_text(output_image_file.name)

-    """创建从图片中提取文本的函数
-    """
-
     def image_to_text(self, image_path: str) -> str:
-        # 从图片中抽取文本
+        """
+        Function to perform OCR and extract text from an image.
+
+        Args:
+            image_path (str): input image path.
+
+        Returns:
+            str: text extracted by OCR.
+ """ result = self.image_reader.readtext(image_path) predictions = "".join([item[1] for item in result]) return predictions - """从页面中提取表格内容 + """Function to extract content from table """ @staticmethod def extract_table(pdf: pdfplumber.PDF, page_num: int, table_num: int) -> List[Any]: - # 查找已检查的页面 table_page = pdf.pages[page_num] - # 提取适当的表格 table = table_page.extract_tables()[table_num] return table - """合并分页表格 + """Function to merge paginated tables """ @staticmethod def merge_page_tables(total_tables: List[PageItem]) -> List[PageItem]: - # 合并分页表格 i = len(total_tables) - 1 while i - 1 >= 0: table = total_tables[i] @@ -135,16 +151,14 @@ def merge_page_tables(total_tables: List[PageItem]) -> List[PageItem]: i -= 1 return total_tables - """将表格转换为适当的格式 + """Function to parse table """ @staticmethod def parse_table(table: List[List]) -> str: table_string = "" - # 遍历表格的每一行 for row_num in range(len(table)): row = table[row_num] - # 从warp的文字删除线路断路器 cleaned_row = [ item.replace("\n", " ") if item is not None and "\n" in item @@ -153,13 +167,11 @@ def parse_table(table: List[List]) -> str: else item for item in row ] - # 将表格转换为字符串,注意'|'、'\n' table_string += "|" + "|".join(cleaned_row) + "|" + "\n" - # 删除最后一个换行符 table_string = table_string.strip() return table_string - """为表格生成摘要 + """Function to summarize table """ @staticmethod @@ -167,18 +179,17 @@ def tables_summarize(table: List[List]) -> str: prompt_text = f"请为以下表格生成一个摘要: {table}" response = Settings.llm.complete( prompt_text, - max_tokens=200, # 调整为所需的摘要长度 - n=1, # 生成摘要的数量 + max_tokens=200, + n=1, ) summarized_text = response return summarized_text - """表格数据转化为json数据 + """Function to convert table data to json """ @staticmethod def table_to_json(table: List[List]) -> str: - # 提取表头 table_info = [] column_name = table[0] for row in range(1, len(table)): @@ -190,23 +201,32 @@ def table_to_json(table: List[List]) -> str: return json.dumps(table_info, ensure_ascii=False) - """创建一个文本提取函数 + """Function to process text in pdf """ @staticmethod def text_extraction(elements: List[LTTextBoxHorizontal]) -> List[str]: - # 找到每一行的坐标 + """ + Extracts text lines from a list of text boxes and handles line breaks under specific conditions. + + Args: + elements: A list of LTTextBoxHorizontal objects representing text boxes on a page. + + Returns: + A list containing the extracted text lines with line breaks removed as per defined conditions. 
+ """ boxes, texts = [], [] - # 页面文字的开始和结束坐标 + # Initialize the start and end coordinates of the page text max_x1 = 0 - min_x0 = sys.maxsize + min_x0 = float("inf") for text_box_h in elements: if isinstance(text_box_h, LTTextBoxHorizontal): for text_box_h_l in text_box_h: if isinstance(text_box_h_l, LTTextLineHorizontal): + # Process each text line's coordinates and content x0, y0, x1, y1 = text_box_h_l.bbox text = text_box_h_l.get_text() - # 判断这一行是否以标点符号结尾。以标点符号结尾的行的结束位置和正常文字的结束位置不同 + # Check if the line ends with punctuation and requires special handling if not ( text[-1] == "\n" and len(text) >= 2 @@ -216,7 +236,7 @@ def text_extraction(elements: List[LTTextBoxHorizontal]) -> List[str]: min_x0 = min(min_x0, x0) texts.append(text) boxes.append((x0, x1)) - # 判断是否去除换行符的条件:该行的结尾坐标大于等于除标点符号结尾的行的坐标向下取整 且 下一行的开头坐标小于等于最小文字坐标取整+1 + # Remove line breaks based on defined conditions for cur in range(len(boxes) - 1): if boxes[cur][1] >= int(max_x1) and boxes[cur + 1][0] <= int(min_x0) + 1: texts[cur] = texts[cur].replace("\n", "") @@ -259,59 +279,58 @@ def load( # open PDF file pdfFileObj = open(file_path, "rb") - # 创建一个PDF阅读器对象 + # Create a PDF reader object pdf_read = PyPDF2.PdfReader(pdfFileObj) total_tables = [] page_items = [] - # 打开pdf文件 + # Open the PDF and extract pages pdf = pdfplumber.open(file_path) - # 从PDF中提取页面 for pagenum, page in enumerate(extract_pages(file_path)): - # 初始化从页面中提取文本所需的变量 + # Initialize variables for extracting text from the page page_object = pdf_read.pages[pagenum] text_elements = [] text_from_images = [] - # 初始化检查表的数量 + # Initialize table count table_num = 0 first_element = True - # 查找已检查的页面 + # Find the checked page page_tables = pdf.pages[pagenum] - # 找出本页上的表格数目 + # Find the number of tables on the page tables = page_tables.find_tables() - # 找到所有的元素 + # Find all elements on the page page_elements = [(element.y1, element) for element in page._objs] - # 对页面中出现的所有元素进行排序 + # Sort the elements on the page by their y1 coordinate page_elements.sort(key=lambda a: a[0], reverse=True) - # 查找组成页面的元素 + # Iterate through the page's elements for i, component in enumerate(page_elements): - # 提取页面布局的元素 + # Extract text elements element = component[1] - # 检查该元素是否为文本元素 + # Check if the element is a text box if isinstance(element, LTTextBoxHorizontal): text_elements.append(element) - # 检查元素中的图像 + # Check for images and extract text from them if OCR is enabled elif isinstance(element, LTFigure) and self.enable_image_ocr: - # 从PDF中提取文字 - image_texts = self.process_image(element, page_object) + # Extract text from the PDF image + image_texts = self.process_pdf_image(element, page_object) text_from_images.append(image_texts) - # 检查表的元素 + # Check for table elements elif isinstance(element, LTRect): - lower_side = sys.maxsize + lower_side = float("inf") upper_side = 0 - # 如果第一个矩形元素 + # If it's the first rectangle element if first_element is True and (table_num + 1) <= len(tables): - # 找到表格的边界框 + # Find the bounding box of the table lower_side = page.bbox[3] - tables[table_num].bbox[3] upper_side = element.y1 - # 从表中提取信息 - tabel_text = PaiPDFReader.extract_table(pdf, pagenum, table_num) + # Extract the table data + table_text = PaiPDFReader.extract_table(pdf, pagenum, table_num) item = PageItem( page_number=pagenum, @@ -319,23 +338,25 @@ def load( item_type="table", element=element, table_num=table_num, - text=tabel_text, + text=table_text, ) total_tables.append(item) - # 让它成为另一个元素 + # Move to the next element first_element = False - # 检查我们是否已经从页面中提取了表 + # Check if we've extracted a table 
                    if element.y0 >= lower_side and element.y1 <= upper_side:
                        pass
-                    elif not isinstance(page_elements[i + 1][1], LTRect):
+                    elif i + 1 < len(page_elements) and not isinstance(
+                        page_elements[i + 1][1], LTRect
+                    ):
                        first_element = True
                        table_num += 1

-            # 文本处理
+            # Text extraction from text elements
             text_from_texts = PaiPDFReader.text_extraction(text_elements)
             page_plain_text = "".join(text_from_texts)
-            # 图片处理
+            # Image text extraction
             page_image_text = "".join(text_from_images)

             page_items.append(
@@ -351,19 +372,19 @@ def load(
                 )
             )

-        # 合并分页表格
+        # Merge tables across pages
         total_tables = PaiPDFReader.merge_page_tables(total_tables)

-        # 构造返回数据
+        # Construct the returned data
         docs = []
         for pagenum, item in enumerate(page_items):
             page_tables_texts = []
             page_tables_summaries = []
             page_tables_json = []
             for table in total_tables:
-                # 如果页面匹配
+                # If the page number matches
                 if pagenum == table["page_number"]:
-                    # 将表信息转换为结构化字符串格式
+                    # Convert the table data to a structured string
                     table_string = PaiPDFReader.parse_table(table["text"])
                     summarized_table_text = PaiPDFReader.tables_summarize(table["text"])
                     json_data = PaiPDFReader.table_to_json(table["text"])
@@ -376,7 +397,7 @@ def load(

             page_info_text = item[0]["text"] + item[1]["text"] + page_table_text

-            # if extra_info is not None, check if it is a dictionary
+            # if `extra_info` is not None, check if it is a dictionary
             if extra_info:
                 if not isinstance(extra_info, dict):
                     raise TypeError("extra_info must be a dictionary.")
diff --git a/src/pai_rag/modules/index/store.py b/src/pai_rag/modules/index/store.py
index 1d369d0e..1a290cdd 100644
--- a/src/pai_rag/modules/index/store.py
+++ b/src/pai_rag/modules/index/store.py
@@ -8,6 +8,7 @@
 from llama_index.vector_stores.chroma import ChromaVectorStore
 from llama_index.vector_stores.elasticsearch import ElasticsearchStore
 from llama_index.vector_stores.milvus import MilvusVectorStore
+from elasticsearch.helpers.vectorstore import AsyncDenseVectorStrategy

 from pai_rag.integrations.vector_stores.vector_stores_hologres.hologres import (
     HologresVectorStore,
@@ -128,6 +129,7 @@ def _get_or_create_es(self):
             es_url=es_config["es_url"],
             es_user=es_config["es_user"],
             es_password=es_config["es_password"],
+            retrieval_strategy=AsyncDenseVectorStrategy(hybrid=True),
         )

     def _get_or_create_milvus(self):
diff --git a/tests/data_readers/test_easyocr_pdf_reader.py b/tests/data_readers/test_easyocr_pdf_reader.py
index 87e33cee..73724d63 100644
--- a/tests/data_readers/test_easyocr_pdf_reader.py
+++ b/tests/data_readers/test_easyocr_pdf_reader.py
@@ -1,28 +1,26 @@
-"""
-class TestEasyOcrPdfReader(unittest.TestCase):
-    def setUp(self):
-        # load config
-        base_dir = Path(__file__).parent.parent.parent
-        config_file = os.path.join(base_dir, "src/pai_rag/config/settings.local.yaml")
-        config = RagConfiguration.from_file(config_file).get_value()
-        module_registry.init_modules(config)
-        reader_config = config["data_reader"]
-        self.directory_reader = SimpleDirectoryReader(
-            input_dir="data/pdf_data",
-            file_extractor={
-                ".pdf": PaiPDFReader(
-                    enable_image_ocr=reader_config.get("enable_image_ocr", False),
-                    model_dir=reader_config.get("easyocr_model_dir", None),
-                )
-            },
-        )
+import os
+from pathlib import Path
+from pai_rag.core.rag_configuration import RagConfiguration
+from pai_rag.modules.module_registry import module_registry
+from pai_rag.integrations.readers.pai_pdf_reader import PaiPDFReader
+from llama_index.core import SimpleDirectoryReader

-    def test_load_documents(self):
-        # load documents
-        self.documents = self.directory_reader.load_data()
-        self.assertEqual(len(self.documents), 40)
+BASE_DIR = Path(__file__).parent.parent.parent

-if __name__ == "__main__":
-    unittest.main()
-"""
+
+def test_pai_pdf_reader():
+    config_file = os.path.join(BASE_DIR, "src/pai_rag/config/settings.toml")
+    config = RagConfiguration.from_file(config_file).get_value()
+    module_registry.init_modules(config)
+    reader_config = config["data_reader"]
+    directory_reader = SimpleDirectoryReader(
+        input_dir="tests/testdata/data/pdf_data",
+        file_extractor={
+            ".pdf": PaiPDFReader(
+                enable_image_ocr=reader_config.get("enable_image_ocr", False),
+                model_dir=reader_config.get("easyocr_model_dir", None),
+            )
+        },
+    )
+    documents = directory_reader.load_data()
+    assert len(documents) > 0
diff --git a/tests/testdata/data/pdf_data/pai_document.pdf b/tests/testdata/data/pdf_data/pai_document.pdf
new file mode 100644
index 00000000..6dddbf80
Binary files /dev/null and b/tests/testdata/data/pdf_data/pai_document.pdf differ