Skip to content

Commit

Permalink
Add test for topk
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Dec 19, 2024
1 parent cd795a7 commit 4e60cc5
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ The `results.json` file will contain a json dictionary where the keys are the in
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `position` - the reading order of the box.
- `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
- `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values.
- `page` - the page number in the file.
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

Expand Down
2 changes: 0 additions & 2 deletions detect_layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import time

import pypdfium2 # Causes a warning if not the top import
import argparse
import copy
import json
Expand Down
2 changes: 0 additions & 2 deletions surya/detection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import contextlib

import torch
from typing import List, Tuple, Generator

Expand Down
10 changes: 6 additions & 4 deletions surya/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List

import numpy as np
import torch
from PIL import Image
Expand Down Expand Up @@ -136,7 +135,7 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_
box_logits = return_dict["bbox_logits"][:current_batch_size, -1, :].detach()
class_logits = return_dict["class_logits"][:current_batch_size, -1, :].detach()

probs = torch.nn.functional.softmax(class_logits, dim=-1)
probs = torch.nn.functional.softmax(class_logits, dim=-1).cpu()
entropy = torch.special.entr(probs).sum(dim=-1)

class_preds = class_logits.argmax(-1)
Expand Down Expand Up @@ -218,9 +217,12 @@ def batch_layout_detection(images: List, model, processor, batch_size=None, top_
top_k_probs = [p["top_k_probs"] for p in preds]
top_k_indices = [p["top_k_indices"] - model.decoder.config.special_token_count for p in preds]

for z, (poly, label) in enumerate(zip(polygons, labels)):
for z, (poly, label, top_k_prob, top_k_index) in enumerate(zip(polygons, labels, top_k_probs, top_k_indices)):
top_k_dict = {
ID_TO_LABEL.get(int(l)): prob.item()
for (l, prob) in zip(top_k_index, top_k_prob) if l > 0
}
l = ID_TO_LABEL[int(label)]
top_k_dict = {ID_TO_LABEL.get(int(l)): prob.item() for (l, prob) in zip(top_k_indices[z], top_k_probs[z]) if l > 0}
lb = LayoutBox(
polygon=poly,
label=l,
Expand Down
14 changes: 13 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import pytest
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
from surya.model.layout.model import load_model as load_layout_model
from surya.model.layout.processor import load_processor as load_layout_processor

@pytest.fixture(scope="session")
def ocr_error_model():
    """Session-scoped fixture that yields the OCR error model.

    The tokenizer returned by ``load_ocr_error_processor`` is attached to the
    model as ``.processor`` so a test can reach both through this one fixture.

    Yields:
        The object returned by ``load_ocr_error_model`` with ``.processor`` set.
    """
    ocr_error_m = load_ocr_error_model()
    ocr_error_p = load_ocr_error_processor()
    ocr_error_m.processor = ocr_error_p
    yield ocr_error_m
    # Teardown: drop the local reference exactly once. Deleting the same
    # local name a second time would raise UnboundLocalError.
    del ocr_error_m


@pytest.fixture(scope="session")
def layout_model():
    """Session-scoped fixture that yields the layout model.

    The processor returned by ``load_layout_processor`` is stored on the model
    as ``.processor`` so tests receive the pair through a single fixture.
    """
    model = load_layout_model()
    model.processor = load_layout_processor()
    yield model
    # Teardown: release this fixture's reference to the model.
    del model

26 changes: 26 additions & 0 deletions tests/test_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
# NOTE(review): presumably set before the surya import below so that PyTorch
# falls back to CPU for ops the Apple MPS backend lacks — confirm.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from surya.layout import batch_layout_detection
from PIL import Image, ImageDraw

def test_layout_topk(layout_model):
    """Check that layout detection reports labels and five top-k entries per box.

    Renders a large heading and a three-line paragraph onto a blank white
    1024x1024 image, runs ``batch_layout_detection`` on it, and asserts that
    two boxes come back: the first labelled ``SectionHeader`` and the second
    ``Text``, each carrying five ``top_k`` entries.
    """
    canvas = Image.new("RGB", (1024, 1024), "white")
    pen = ImageDraw.Draw(canvas)
    pen.text((10, 10), "Hello World", fill="black", font_size=72)
    paragraph = "This is a sentence of text.\nNow it is a paragraph.\nA three-line one."
    pen.text((10, 200), paragraph, fill="black", font_size=24)

    results = batch_layout_detection([canvas], layout_model, layout_model.processor)

    # One input image must give exactly one result covering the whole image.
    assert len(results) == 1
    page = results[0]
    assert page.image_bbox == [0, 0, 1024, 1024]

    boxes = page.bboxes
    assert len(boxes) == 2

    # Boxes are checked in returned order: heading first, paragraph second.
    for box, expected_label in zip(boxes, ("SectionHeader", "Text")):
        assert box.label == expected_label
        assert len(box.top_k) == 5

0 comments on commit 4e60cc5

Please sign in to comment.