From a5167be2c0ddc0634ea365a6a4a211d1d6042770 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 13:37:19 +0800 Subject: [PATCH 1/6] fix: correct yolo cls model preprocess --- .gitignore | 2 ++ README.md | 16 ++++++++++------ table_cls/main.py | 11 ++++++++--- table_cls/utils.py | 44 ++++++++++++++++++++++++++++++-------------- 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index e73647b..e5a1c70 100755 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,5 @@ long1.jpg .DS_Store *.npy outputs/ +/tests/test_files/standard_dataset/ +/lineless_table_rec/images/ diff --git a/README.md b/README.md index 685cdf1..6de39ec 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,11 @@ ### 最近更新 - **2024.10.13** - 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库) -- **2024.10.17** - - 补充最新surya 表格识别测评结果 - **2024.10.22** - - 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection) + - 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection) +- **2024.10.29** + - 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评 + ### 简介 💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。 @@ -57,10 +58,10 @@ | [deepdoctection(rag-flow)](https://github.com/deepdoctection/deepdoctection?tab=readme-ov-file) | 0.59975 | 0.69918 | | [ppstructure_table_master](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.61606 | 0.73892 | | [ppsturcture_table_engine](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.67924 | 0.78653 | -| table_cls + wired_table_rec v1 + lineless_table_rec | 0.68507 | 0.75140 | | [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) | 0.67310 | 0.81210 | | [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 | -| table_cls + wired_table_rec v2 + lineless_table_rec | 0.73702 | 0.80210 | +| table_cls + wired_table_rec v1 + lineless_table_rec | 0.75288 | 0.82574 | +| table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 | | [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | **0.84481** | **0.91369** | ### 使用建议 @@ -87,6 +88,8 @@ from wired_table_rec import WiredTableRecognition lineless_engine = LinelessTableRecognition() wired_engine = WiredTableRecognition() table_cls = TableCls() +# 分类精度降低,但耗时减少 3/5(0.2s->0.08s) +# table_cls = TableCls(mode="q") img_path = f'images/img14.jpg' cls,elasp = table_cls(img_path) @@ -158,7 +161,8 @@ for i, res in enumerate(result): - [x] 图片小角度偏移修正方法补充 - [x] 增加数据集数量,增加更多评测对比 - [x] 补充复杂场景表格检测和提取,解决旋转和透视导致的低识别率 -- [ ] 优化表格分类器,优化无线表格模型 +- [x] 优化表格分类器 +- [ ] 优化无线表格模型 ### 处理流程 diff --git a/table_cls/main.py b/table_cls/main.py index ca7ab4c..11431e4 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -5,7 +5,7 @@ import numpy as np from PIL import Image -from .utils import InputType, LoadImage, OrtInferSession, ResizePad +from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop cur_dir = Path(__file__).resolve().parent q_cls_model_path = cur_dir / "models" / "table_cls.onnx" @@ -64,10 +64,15 @@ class YoloCls: def __init__(self, model_path): self.table_cls = OrtInferSession(model_path) self.cls = {0: "wireless", 1: "wired"} + self.mean = np.array([0, 0, 0], dtype=np.float32) + self.std = np.array([1, 1, 1], dtype=np.float32) def preprocess(self, img): - img, *_ = ResizePad(img, 640) - img = np.array(img, dtype=np.float32) / 255.0 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = resize_and_center_crop(img, 640) + img = np.array(img, dtype=np.float32) / 255 + img -= self.mean + img /= self.std img = img.transpose(2, 0, 1) # HWC to CHW img = np.expand_dims(img, axis=0) # Add batch dimension, only one image return img diff --git a/table_cls/utils.py b/table_cls/utils.py index 9df30f7..ed31827 100644 --- a/table_cls/utils.py +++ b/table_cls/utils.py @@ -180,17 +180,33 @@ def verify_exist(file_path: Union[str, Path]): raise LoadImageError(f"{file_path} does not exist.") -def ResizePad(img, target_size): - h, w = img.shape[:2] - m = max(h, w) - ratio = target_size / m - new_w, new_h = int(ratio * w), int(ratio * h) - img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) - top = (target_size - new_h) // 2 - bottom = (target_size - new_h) - top - left = (target_size - new_w) // 2 - right = (target_size - new_w) - left - img1 = cv2.copyMakeBorder( - img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) - ) - return img1, new_w, new_h, left, top +def resize_and_center_crop(image: np.ndarray, target_size: int): + """ + Resize the image so that the smallest side is equal to the target size, + then crop the center of the image to the specified target size. + + Args: + image (np.ndarray): Input image as a NumPy array with shape (height, width, channels). + target_size (int): Target size for the smallest side of the image and the output size. + + Returns: + (np.ndarray): Resized and cropped image as a NumPy array. + """ + # 获取输入图像的尺寸 + h, w = image.shape[:2] + + # 计算缩放比例 + scale = target_size / min(h, w) + new_h, new_w = int(h * scale), int(w * scale) + + # 缩放图像 + resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + + # 计算裁剪的起始位置 + i = (new_h - target_size) // 2 + j = (new_w - target_size) // 2 + + # 裁剪图像 + cropped_image = resized_image[i : i + target_size, j : j + target_size] + + return cropped_image From 29c965f7ca0a1e3c6f29550d346209abdc844716 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 13:38:45 +0800 Subject: [PATCH 2/6] chore: fix gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index e5a1c70..e73647b 100755 --- a/.gitignore +++ b/.gitignore @@ -158,5 +158,3 @@ long1.jpg .DS_Store *.npy outputs/ -/tests/test_files/standard_dataset/ -/lineless_table_rec/images/ From a4339fc63749aa49a3bb91a164a90ad2208199d4 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 14:01:20 +0800 Subject: [PATCH 3/6] fix: adjust resize mode --- table_cls/main.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/table_cls/main.py b/table_cls/main.py index 11431e4..554a973 100644 --- a/table_cls/main.py +++ b/table_cls/main.py @@ -5,17 +5,20 @@ import numpy as np from PIL import Image -from .utils import InputType, LoadImage, OrtInferSession, resize_and_center_crop +from .utils import InputType, LoadImage, OrtInferSession cur_dir = Path(__file__).resolve().parent q_cls_model_path = cur_dir / "models" / "table_cls.onnx" yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx" +yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx" class TableCls: def __init__(self, model_type="yolo", model_path=yolo_cls_model_path): if model_type == "yolo": self.table_engine = YoloCls(model_path) + elif model_type == "yolox": + self.table_engine = YoloCls(yolo_cls_x_model_path) else: model_path = q_cls_model_path self.table_engine = QanythingCls(model_path) @@ -64,15 +67,11 @@ class YoloCls: def __init__(self, model_path): self.table_cls = OrtInferSession(model_path) self.cls = {0: "wireless", 1: "wired"} - self.mean = np.array([0, 0, 0], dtype=np.float32) - self.std = np.array([1, 1, 1], dtype=np.float32) def preprocess(self, img): img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img = resize_and_center_crop(img, 640) + img = cv2.resize(img, (640, 640)) img = np.array(img, dtype=np.float32) / 255 - img -= self.mean - img /= self.std img = img.transpose(2, 0, 1) # HWC to CHW img = np.expand_dims(img, axis=0) # Add batch dimension, only one image return img From 1a4974c2f7d3a32e26913791405e8add47410183 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 14:14:47 +0800 Subject: [PATCH 4/6] chore: add table_cls use desc --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6de39ec..9b4c0b3 100644 --- a/README.md +++ b/README.md @@ -87,9 +87,8 @@ from wired_table_rec import WiredTableRecognition lineless_engine = LinelessTableRecognition() wired_engine = WiredTableRecognition() -table_cls = TableCls() -# 分类精度降低,但耗时减少 3/5(0.2s->0.08s) -# table_cls = TableCls(mode="q") +# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型 +table_cls = TableCls() # TableCls(mode="yolox"),TableCls(mode="q") img_path = f'images/img14.jpg' cls,elasp = table_cls(img_path) From 6db1c1cc2c828caefa4753cf02cc8d77d19b4501 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 14:16:13 +0800 Subject: [PATCH 5/6] fix: rm unuse code --- table_cls/utils.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/table_cls/utils.py b/table_cls/utils.py index ed31827..db64b3a 100644 --- a/table_cls/utils.py +++ b/table_cls/utils.py @@ -178,35 +178,3 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray: def verify_exist(file_path: Union[str, Path]): if not Path(file_path).exists(): raise LoadImageError(f"{file_path} does not exist.") - - -def resize_and_center_crop(image: np.ndarray, target_size: int): - """ - Resize the image so that the smallest side is equal to the target size, - then crop the center of the image to the specified target size. - - Args: - image (np.ndarray): Input image as a NumPy array with shape (height, width, channels). - target_size (int): Target size for the smallest side of the image and the output size. - - Returns: - (np.ndarray): Resized and cropped image as a NumPy array. - """ - # 获取输入图像的尺寸 - h, w = image.shape[:2] - - # 计算缩放比例 - scale = target_size / min(h, w) - new_h, new_w = int(h * scale), int(w * scale) - - # 缩放图像 - resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) - - # 计算裁剪的起始位置 - i = (new_h - target_size) // 2 - j = (new_w - target_size) // 2 - - # 裁剪图像 - cropped_image = resized_image[i : i + target_size, j : j + target_size] - - return cropped_image From cd2ea5306afe7ccfd721f2ffc8ea19519aa91f1d Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Tue, 29 Oct 2024 14:19:05 +0800 Subject: [PATCH 6/6] fix: fix desc for table cls --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9b4c0b3..eaccdc6 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ from wired_table_rec import WiredTableRecognition lineless_engine = LinelessTableRecognition() wired_engine = WiredTableRecognition() # 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型 -table_cls = TableCls() # TableCls(mode="yolox"),TableCls(mode="q") +table_cls = TableCls() # TableCls(model_type="yolox"),TableCls(model_type="q") img_path = f'images/img14.jpg' cls,elasp = table_cls(img_path)