diff --git a/src/pai_rag/app/web/tabs/upload_tab.py b/src/pai_rag/app/web/tabs/upload_tab.py index 7cc67194..178eb1e8 100644 --- a/src/pai_rag/app/web/tabs/upload_tab.py +++ b/src/pai_rag/app/web/tabs/upload_tab.py @@ -16,6 +16,7 @@ def upload_oss_knowledge( chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, upload_index, ): @@ -35,6 +36,7 @@ def upload_oss_knowledge( chunk_overlap=chunk_overlap, enable_raptor=enable_raptor, enable_multimodal=enable_multimodal, + enable_mandatory_ocr=enable_mandatory_ocr, enable_table_summary=enable_table_summary, index_name=upload_index, from_oss=True, @@ -48,6 +50,7 @@ def upload_files( chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, upload_index, ): @@ -67,6 +70,7 @@ def upload_files( chunk_overlap=chunk_overlap, enable_raptor=enable_raptor, enable_multimodal=enable_multimodal, + enable_mandatory_ocr=enable_mandatory_ocr, enable_table_summary=enable_table_summary, index_name=upload_index, ): @@ -80,6 +84,7 @@ def upload_knowledge( chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, index_name, from_oss: bool = False, @@ -89,6 +94,7 @@ def upload_knowledge( { "chunk_size": chunk_size, "chunk_overlap": chunk_overlap, + "enable_mandatory_ocr": enable_mandatory_ocr, "enable_table_summary": enable_table_summary, } ) @@ -188,6 +194,12 @@ def create_upload_tab() -> Dict[str, Any]: elem_id="enable_multimodal", visible=True, ) + enable_mandatory_ocr = gr.Checkbox( + label="Yes", + info="Process PDF with OCR", + elem_id="enable_mandatory_ocr", + visible=True, + ) enable_table_summary = gr.Checkbox( label="Yes", info="Process with Table Summary ", @@ -232,6 +244,7 @@ def create_upload_tab() -> Dict[str, Any]: chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, upload_index, ], @@ -247,6 +260,7 @@ def create_upload_tab() -> Dict[str, Any]: chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, upload_index, ], @@ -269,6 +283,7 @@ def create_upload_tab() -> Dict[str, Any]: chunk_overlap, enable_raptor, enable_multimodal, + enable_mandatory_ocr, enable_table_summary, upload_index, ], @@ -287,5 +302,6 @@ def create_upload_tab() -> Dict[str, Any]: chunk_overlap.elem_id: chunk_overlap, enable_raptor.elem_id: enable_raptor, enable_multimodal.elem_id: enable_multimodal, + enable_mandatory_ocr.elem_id: enable_mandatory_ocr, enable_table_summary.elem_id: enable_table_summary, } diff --git a/src/pai_rag/app/web/view_model.py b/src/pai_rag/app/web/view_model.py index d23a58ee..c0bdec45 100644 --- a/src/pai_rag/app/web/view_model.py +++ b/src/pai_rag/app/web/view_model.py @@ -71,6 +71,7 @@ class ViewModel(BaseModel): # reader reader_type: str = "SimpleDirectoryReader" enable_raptor: bool = False + enable_mandatory_ocr: bool = False enable_table_summary: bool = False config_file: str = None @@ -174,6 +175,7 @@ def from_app_config(config: RagConfig): view_model.chunk_overlap = config.node_parser.chunk_overlap view_model.chunk_size = config.node_parser.chunk_size + view_model.enable_mandatory_ocr = config.data_reader.enable_mandatory_ocr view_model.enable_table_summary = config.data_reader.enable_table_summary view_model.similarity_top_k = config.retriever.similarity_top_k @@ -282,6 +284,7 @@ def to_app_config(self): config["node_parser"]["chunk_size"] = int(self.chunk_size) config["node_parser"]["chunk_overlap"] = int(self.chunk_overlap) + config["data_reader"]["enable_mandatory_ocr"] = self.enable_mandatory_ocr config["data_reader"]["enable_table_summary"] = self.enable_table_summary config["retriever"]["similarity_top_k"] = self.similarity_top_k @@ -506,6 +509,7 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]: settings["chunk_overlap"] = {"value": self.chunk_overlap} settings["enable_raptor"] = {"value": self.enable_raptor} settings["enable_multimodal"] = {"value": self.enable_multimodal} + settings["enable_mandatory_ocr"] = {"value": self.enable_mandatory_ocr} settings["enable_table_summary"] = {"value": self.enable_table_summary} # retrieval and rerank diff --git a/src/pai_rag/integrations/readers/pai/pai_data_reader.py b/src/pai_rag/integrations/readers/pai/pai_data_reader.py index f9a1e4eb..bb5223b6 100644 --- a/src/pai_rag/integrations/readers/pai/pai_data_reader.py +++ b/src/pai_rag/integrations/readers/pai/pai_data_reader.py @@ -22,6 +22,7 @@ class BaseDataReaderConfig(BaseModel): concat_csv_rows: bool = False + enable_mandatory_ocr: bool = False enable_table_summary: bool = False format_sheet_data_to_json: bool = False sheet_column_filters: List[str] | None = None @@ -45,6 +46,7 @@ def get_file_readers(reader_config: BaseDataReaderConfig = None, oss_store: Any oss_cache=oss_store, # Storing docx images ), ".pdf": PaiPDFReader( + enable_mandatory_ocr=reader_config.enable_mandatory_ocr, enable_table_summary=reader_config.enable_table_summary, oss_cache=oss_store, # Storing pdf images ), diff --git a/src/pai_rag/integrations/readers/pai_pdf_reader.py b/src/pai_rag/integrations/readers/pai_pdf_reader.py index 3dcd5069..42f55827 100644 --- a/src/pai_rag/integrations/readers/pai_pdf_reader.py +++ b/src/pai_rag/integrations/readers/pai_pdf_reader.py @@ -14,7 +14,6 @@ from magic_pdf.pipe.UNIPipe import UNIPipe from magic_pdf.pipe.OCRPipe import OCRPipe -from magic_pdf.pipe.TXTPipe import TXTPipe import magic_pdf.model as model_config from rapidocr_onnxruntime import RapidOCR from rapid_table import RapidTable @@ -52,14 +51,19 @@ class PaiPDFReader(BaseReader): def __init__( self, + enable_mandatory_ocr: bool = False, enable_table_summary: bool = False, oss_cache: Any = None, ) -> None: self.enable_table_summary = enable_table_summary + self.enable_mandatory_ocr = enable_mandatory_ocr self._oss_cache = oss_cache logger.info( f"PaiPdfReader created with enable_table_summary : {self.enable_table_summary}" ) + logger.info( + f"PaiPdfReader created with enable_mandatory_ocr : {self.enable_mandatory_ocr}" + ) def _transform_local_to_oss(self, pdf_name: str, local_url: str): image = Image.open(local_url) @@ -270,7 +274,7 @@ def parse_pdf( 执行从 pdf 转换到 json、md 的过程,输出 md 和 json 文件到 pdf 文件所在的目录 :param pdf_path: .pdf 文件的路径,可以是相对路径,也可以是绝对路径 - :param parse_method: 解析方法, 共 auto、ocr、txt 三种,默认 auto,如果效果不好,可以尝试 ocr + :param parse_method: 解析方法, 共 auto、ocr两种,默认 auto。auto会根据文件类型选择TXT模式或者OCR模式解析。ocr会直接使用OCR模式。 :param model_json_path: 已经存在的模型数据文件,如果为空则使用内置模型,pdf 和 model_json 务必对应 """ try: @@ -294,8 +298,6 @@ def parse_pdf( if parse_method == "auto": jso_useful_key = {"_pdf_type": "", "model_list": model_json} pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer) - elif parse_method == "txt": - pipe = TXTPipe(pdf_bytes, model_json, image_writer) elif parse_method == "ocr": pipe = OCRPipe(pdf_bytes, model_json, image_writer) else: @@ -358,8 +360,11 @@ def load( Returns: List[Document]: list of documents. """ - - md_content = self.parse_pdf(file_path, "auto") + if self.enable_mandatory_ocr: + parse_method = "ocr" + else: + parse_method = "auto" + md_content = self.parse_pdf(file_path, parse_method) logger.info(f"[PaiPDFReader] successfully processed pdf file {file_path}.") docs = [] if metadata: