diff --git a/README.md b/README.md index c5493e1..4fb5cef 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,7 @@ The `results.json` file will contain a json dictionary where the keys are the in - `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left. - `position` - the reading order of the box. - `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`. + - `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values. - `page` - the page number in the file - `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox. diff --git a/detect_layout.py b/detect_layout.py index 6524af2..1adb5a9 100644 --- a/detect_layout.py +++ b/detect_layout.py @@ -1,6 +1,4 @@ import time - -import pypdfium2 # Causes a warning if not the top import import argparse import copy import json diff --git a/surya/detection.py b/surya/detection.py index 4d0ac9d..7ae2167 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -1,5 +1,3 @@ -import contextlib - import torch from typing import List, Tuple, Generator diff --git a/surya/layout.py b/surya/layout.py index 476e54e..2378d82 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -1,5 +1,4 @@ from typing import List - import numpy as np import torch from PIL import Image @@ -136,7 +135,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach() class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach() - probs = torch.nn.functional.softmax(class_logits, dim=-1) + probs = torch.nn.functional.softmax(class_logits, dim=-1).cpu() entropy = torch.special.entr(probs).sum(dim=-1) class_preds = class_logits.argmax(-1) @@ -218,9 +217,12 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_ top_k_probs = [p["top_k_probs"] for p in preds] top_k_indices = [p["top_k_indices"] - model.decoder.config.special_token_count for p in preds] - for z, (poly, label) in enumerate(zip(polygons, labels)): + for z, (poly, label, top_k_prob, top_k_index) in enumerate(zip(polygons, labels, top_k_probs, top_k_indices)): + top_k_dict = { + ID_TO_LABEL.get(int(l)): prob.item() + for (l, prob) in zip(top_k_index, top_k_prob) if l > 0 + } l = ID_TO_LABEL[int(label)] - top_k_dict = {ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0} lb = LayoutBox( polygon=poly, label=l, diff --git a/tests/conftest.py b/tests/conftest.py index 8f37c5b..fb9a864 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ import pytest from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor +from surya.model.layout.model import load_model as load_layout_model +from surya.model.layout.processor import load_processor as load_layout_processor @pytest.fixture(scope="session") def ocr_error_model(): @@ -7,4 +9,14 @@ def ocr_error_model(): ocr_error_p = load_ocr_error_processor() ocr_error_m.processor = ocr_error_p yield ocr_error_m - del ocr_error_m \ No newline at end of file + del ocr_error_m + + +@pytest.fixture(scope="session") +def layout_model(): + layout_m = load_layout_model() + layout_p = load_layout_processor() + layout_m.processor = layout_p + yield layout_m + del layout_m + diff --git a/tests/test_layout.py b/tests/test_layout.py new file mode 100644 index 0000000..bc140e6 --- /dev/null +++ b/tests/test_layout.py @@ -0,0 +1,26 @@ +import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +from surya.layout import batch_layout_detection +from PIL import Image, ImageDraw + +def test_layout_topk(layout_model): + image = Image.new("RGB", (1024, 1024), "white") + draw = ImageDraw.Draw(image) + draw.text((10, 10), "Hello World", fill="black", font_size=72) + draw.text((10, 200), "This is a sentence of text.\nNow it is a paragraph.\nA three-line one.", fill="black", + font_size=24) + + layout_results = batch_layout_detection([image], layout_model, layout_model.processor) + + assert len(layout_results) == 1 + assert layout_results[0].image_bbox == [0, 0, 1024, 1024] + + bboxes = layout_results[0].bboxes + assert len(bboxes) == 2 + + assert bboxes[0].label == "SectionHeader" + assert len(bboxes[0].top_k) == 5 + + assert bboxes[1].label == "Text" + assert len(bboxes[1].top_k) == 5