Skip to content

Commit

Permalink
Fix table recognition bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
VikParuchuri committed Oct 18, 2024
1 parent 5a4efb1 commit 893f586
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.6.5"
version = "0.6.6"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
authors = ["Vik Paruchuri <[email protected]>"]
readme = "README.md"
Expand Down
20 changes: 13 additions & 7 deletions surya/model/table_rec/processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import random
from typing import Dict, Union, Optional, List, Iterable

import cv2
Expand Down Expand Up @@ -178,18 +179,18 @@ def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
self._in_target_context_manager = False
self.max_input_boxes = kwargs.get("max_input_boxes", 256)
self.max_input_boxes = kwargs.get("max_input_boxes", settings.TABLE_REC_MAX_ROWS)
self.extra_input_boxes = kwargs.get("extra_input_boxes", 32)

def resize_boxes(self, img, boxes):
width, height = img.size
box_width, box_height = self.box_size
for box in boxes:
# Rescale to 0-1024
box[0] = box[0] / width * box_width
box[1] = box[1] / height * box_height
box[2] = box[2] / width * box_width
box[3] = box[3] / height * box_height
box[0] = math.ceil(box[0] / width * box_width)
box[1] = math.ceil(box[1] / height * box_height)
box[2] = math.floor(box[2] / width * box_width)
box[3] = math.floor(box[3] / height * box_height)

if box[0] < 0:
box[0] = 0
Expand All @@ -200,6 +201,9 @@ def resize_boxes(self, img, boxes):
if box[3] > box_height:
box[3] = box_height

boxes = [b for b in boxes if b[3] > b[1] and b[2] > b[0]]
boxes = [b for b in boxes if (b[3] - b[1]) * (b[2] - b[0]) > 10]

return boxes

def __call__(self, *args, **kwargs):
Expand All @@ -212,9 +216,11 @@ def __call__(self, *args, **kwargs):
args = args[1:]

for i in range(len(boxes)):
random.seed(1)
if len(boxes[i]) > self.max_input_boxes:
downsample_ratio = math.ceil(len(boxes[i]) / self.max_input_boxes)
boxes[i] = boxes[i][::downsample_ratio]
downsample_ratio = self.max_input_boxes / len(boxes[i])
boxes[i] = [b for b in boxes[i] if random.random() < downsample_ratio]
boxes[i] = boxes[i][:self.max_input_boxes]

new_boxes = []
max_len = self.max_input_boxes + self.extra_input_boxes
Expand Down
5 changes: 5 additions & 0 deletions table_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,17 @@ def main():
cols = [l.bbox for l in pred.cols]
row_labels = [f"Row {l.row_id}" for l in pred.rows]
col_labels = [f"Col {l.col_id}" for l in pred.cols]
cells = [l.bbox for l in pred.cells]

rc_image = copy.deepcopy(table_img)
rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue")
rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red")
rc_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png"))

cell_image = copy.deepcopy(table_img)
cell_image = draw_bboxes_on_image(cells, cell_image, color="green")
cell_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png"))

with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
json.dump(table_predictions, f, ensure_ascii=False)

Expand Down

0 comments on commit 893f586

Please sign in to comment.