Skip to content

Commit

Permalink
Merge pull request #61 from RapidAI/correct_cls_preprocess
Browse files Browse the repository at this point in the history
support diff size cls model
  • Loading branch information
Joker1212 authored Oct 29, 2024
2 parents 37fa544 + cd2ea53 commit 517825c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 26 deletions.
17 changes: 10 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
### 最近更新
- **2024.10.13**
- 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库)
- **2024.10.17**
- 补充最新surya 表格识别测评结果
- **2024.10.22**
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
- **2024.10.29**
- 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评

### 简介
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。

Expand Down Expand Up @@ -57,10 +58,10 @@
| [deepdoctection(rag-flow)](https://github.com/deepdoctection/deepdoctection?tab=readme-ov-file) | 0.59975 | 0.69918 |
| [ppstructure_table_master](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.61606 | 0.73892 |
| [ppsturcture_table_engine](https://github.com/PaddlePaddle/PaddleOCR/tree/main/ppstructure) | 0.67924 | 0.78653 |
| table_cls + wired_table_rec v1 + lineless_table_rec | 0.68507 | 0.75140 |
| [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy) | 0.67310 | 0.81210 |
| [RapidTable(SLANet)](https://github.com/RapidAI/RapidTable) | 0.71654 | 0.81067 |
| table_cls + wired_table_rec v2 + lineless_table_rec | 0.73702 | 0.80210 |
| table_cls + wired_table_rec v1 + lineless_table_rec | 0.75288 | 0.82574 |
| table_cls + wired_table_rec v2 + lineless_table_rec | 0.77676 | 0.84580 |
| [RapidTable(SLANet-plus)](https://github.com/RapidAI/RapidTable) | **0.84481** | **0.91369** |

### 使用建议
Expand All @@ -86,7 +87,8 @@ from wired_table_rec import WiredTableRecognition

lineless_engine = LinelessTableRecognition()
wired_engine = WiredTableRecognition()
table_cls = TableCls()
# 默认小yolo模型(0.1s),可切换为精度更高yolox(0.25s),更快的qanything(0.07s)模型
table_cls = TableCls() # TableCls(model_type="yolox"),TableCls(model_type="q")
img_path = f'images/img14.jpg'

cls,elasp = table_cls(img_path)
Expand Down Expand Up @@ -158,7 +160,8 @@ for i, res in enumerate(result):
- [x] 图片小角度偏移修正方法补充
- [x] 增加数据集数量,增加更多评测对比
- [x] 补充复杂场景表格检测和提取,解决旋转和透视导致的低识别率
- [ ] 优化表格分类器,优化无线表格模型
- [x] 优化表格分类器
- [ ] 优化无线表格模型

### 处理流程

Expand Down
10 changes: 7 additions & 3 deletions table_cls/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
import numpy as np
from PIL import Image

from .utils import InputType, LoadImage, OrtInferSession, ResizePad
from .utils import InputType, LoadImage, OrtInferSession

cur_dir = Path(__file__).resolve().parent
q_cls_model_path = cur_dir / "models" / "table_cls.onnx"
yolo_cls_model_path = cur_dir / "models" / "yolo_cls.onnx"
yolo_cls_x_model_path = cur_dir / "models" / "yolo_cls_x.onnx"


class TableCls:
def __init__(self, model_type="yolo", model_path=yolo_cls_model_path):
if model_type == "yolo":
self.table_engine = YoloCls(model_path)
elif model_type == "yolox":
self.table_engine = YoloCls(yolo_cls_x_model_path)
else:
model_path = q_cls_model_path
self.table_engine = QanythingCls(model_path)
Expand Down Expand Up @@ -66,8 +69,9 @@ def __init__(self, model_path):
self.cls = {0: "wireless", 1: "wired"}

def preprocess(self, img):
img, *_ = ResizePad(img, 640)
img = np.array(img, dtype=np.float32) / 255.0
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640))
img = np.array(img, dtype=np.float32) / 255
img = img.transpose(2, 0, 1) # HWC to CHW
img = np.expand_dims(img, axis=0) # Add batch dimension, only one image
return img
Expand Down
16 changes: 0 additions & 16 deletions table_cls/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,3 @@ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
def verify_exist(file_path: Union[str, Path]):
if not Path(file_path).exists():
raise LoadImageError(f"{file_path} does not exist.")


def ResizePad(img, target_size):
h, w = img.shape[:2]
m = max(h, w)
ratio = target_size / m
new_w, new_h = int(ratio * w), int(ratio * h)
img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR)
top = (target_size - new_h) // 2
bottom = (target_size - new_h) - top
left = (target_size - new_w) // 2
right = (target_size - new_w) - left
img1 = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
)
return img1, new_w, new_h, left, top

0 comments on commit 517825c

Please sign in to comment.