Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bad OCR detection model #268

Merged
merged 14 commits into from
Dec 19, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,10 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr tesseract-ocr-eng
- name: Install python dependencies
run: |
pip install poetry
poetry install
poetry remove torch
poetry run pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Run detection benchmark test
run: |
poetry run python benchmark/detection.py --max 2
Expand Down
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Workflow that installs the project and runs its pytest suite.
name: Integration test

# Triggered by every push.
on: [push]

# Environment shared by all jobs below.
# NOTE(review): presumably read by the project settings to force CPU inference - confirm.
env:
TORCH_DEVICE: "cpu"

jobs:
build:
runs-on: ubuntu-latest
steps:
# Fetch the repository contents.
- uses: actions/checkout@v3
# Provision the Python interpreter used by the steps that follow.
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
# System packages: the Tesseract OCR engine and its English language data.
- name: Install apt dependencies
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr tesseract-ocr-eng
# Install poetry, then let it install the project's declared dependencies.
- name: Install python dependencies
run: |
pip install poetry
poetry install
# Run the test suite inside the poetry-managed environment.
- name: Run tests
run: poetry run pytest
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ The `results.json` file will contain a json dictionary where the keys are the in
- `polygon` - the polygon for the text line in (x1, y1), (x2, y2), (x3, y3), (x4, y4) format. The points are in clockwise order from the top left.
- `position` - the reading order of the box.
- `label` - the label for the bbox. One of `Caption`, `Footnote`, `Formula`, `List-item`, `Page-footer`, `Page-header`, `Picture`, `Figure`, `Section-header`, `Table`, `Form`, `Table-of-contents`, `Handwriting`, `Text`, `Text-inline-math`.
- `top_k` - the top-k other potential labels for the box. A dictionary with labels as keys and confidences as values.
- `page` - the page number in the file
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.

Expand Down
2 changes: 0 additions & 2 deletions detect_layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import time

import pypdfium2 # Causes a warning if not the top import
import argparse
import copy
import json
Expand Down
59 changes: 55 additions & 4 deletions ocr_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import tempfile
from typing import List

import pypdfium2
Expand All @@ -15,6 +16,7 @@
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.model.ocr_error.model import load_model as load_ocr_error_model, load_tokenizer as load_ocr_error_processor
from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
Expand All @@ -24,7 +26,9 @@
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, TableResult
from surya.settings import settings
from surya.tables import batch_table_recognition
from surya.postprocessing.util import rescale_bboxes, rescale_bbox
from surya.postprocessing.util import rescale_bbox
from pdftext.extraction import plain_text_output
from surya.ocr_error import batch_ocr_error_detection


@st.cache_resource()
Expand All @@ -46,6 +50,39 @@ def load_layout_cached():
def load_table_cached():
return load_table_model(), load_table_processor()

@st.cache_resource()
def load_ocr_error_cached():
    """Load the OCR error detection model together with its tokenizer.

    The result is wrapped in ``st.cache_resource`` so Streamlit reuses the
    loaded objects instead of loading them again on every script rerun.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = load_ocr_error_model()
    processor = load_ocr_error_processor()
    return model, processor


def run_ocr_errors(pdf_file, page_count, sample_len=512, max_samples=10, max_pages=15):
    """Judge whether the text embedded in a PDF looks like garbled OCR output.

    Text is extracted from a window of pages centred on the middle of the
    document, cut into evenly spaced samples, and each sample is classified by
    the OCR error detection model. The PDF is reported as bad when more than
    20% of the samples are labelled ``"bad"``.

    Args:
        pdf_file: Uploaded file object exposing ``getvalue()`` (the raw PDF bytes).
        page_count: Number of pages in the PDF. ``None`` or ``0`` is treated
            as "nothing to sample".
        sample_len: Maximum number of characters in each sample.
        max_samples: Upper bound on the number of samples sent to the model.
            Must be a positive integer.
        max_pages: Number of pages to take on each side of the middle page.

    Returns:
        A ``(message, labels)`` tuple: a human-readable verdict and the
        per-sample labels (or ``["no text"]`` when there was nothing to classify).
    """
    # Without a page count (e.g. the upload was not a PDF) there is nothing to
    # extract; answer the same way as for a PDF with no text instead of failing
    # on the arithmetic below.
    if not page_count:
        return "This PDF has no text or very little text", ["no text"]

    with tempfile.NamedTemporaryFile(suffix=".pdf") as f:
        f.write(pdf_file.getvalue())
        # The extractor is handed the file *name*, so the buffered bytes must
        # reach the file before it is reopened.
        f.flush()

        # Sample the text from the middle of the PDF
        page_middle = page_count // 2
        first_page = max(page_middle - max_pages, 0)
        last_page = min(page_middle + max_pages, page_count)
        text = plain_text_output(f.name, page_range=range(first_page, last_page))

    sample_gap = len(text) // max_samples
    if not text or sample_gap == 0:
        return "This PDF has no text or very little text", ["no text"]

    # Never step by less than one sample length, so samples do not overlap.
    sample_gap = max(sample_gap, sample_len)

    # Split the text into samples for the model. The range is capped because a
    # text length that is not a multiple of max_samples would otherwise yield
    # one extra, very short trailing sample.
    sample_starts = range(0, len(text), sample_gap)[:max_samples]
    samples = [text[start:start + sample_len] for start in sample_starts]

    results = batch_ocr_error_detection(samples, ocr_error_model, ocr_error_processor)
    # NOTE(review): assumes the model returns one label per sample, so the
    # list is non-empty here - confirm against batch_ocr_error_detection.
    bad_ratio = results.labels.count("bad") / len(results.labels)
    if bad_ratio > .2:
        label = "This PDF may have garbled or bad OCR text."
    else:
        label = "This PDF has good text."
    return label, results.labels


def text_detection(img) -> (Image.Image, TextDetectionResult):
pred = batch_text_detection([img], det_model, det_processor)[0]
Expand Down Expand Up @@ -139,13 +176,16 @@ def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI):
)
png = list(renderer)[0]
png_image = png.convert("RGB")
doc.close()
return png_image


@st.cache_data()
def page_count(pdf_file):
def page_counter(pdf_file):
doc = open_pdf(pdf_file)
return len(doc)
doc_len = len(doc)
doc.close()
return doc_len


st.set_page_config(layout="wide")
Expand All @@ -155,6 +195,7 @@ def page_count(pdf_file):
rec_model, rec_processor = load_rec_cached()
layout_model, layout_processor = load_layout_cached()
table_model, table_processor = load_table_cached()
ocr_error_model, ocr_error_processor = load_ocr_error_cached()


st.markdown("""
Expand All @@ -179,8 +220,9 @@ def page_count(pdf_file):

filetype = in_file.type
whole_image = False
page_count = None
if "pdf" in filetype:
page_count = page_count(in_file)
page_count = page_counter(in_file)
page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count)

pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI)
Expand All @@ -194,6 +236,7 @@ def page_count(pdf_file):
text_rec = st.sidebar.button("Run OCR")
layout_det = st.sidebar.button("Run Layout Analysis")
table_rec = st.sidebar.button("Run Table Rec")
ocr_errors = st.sidebar.button("Run bad PDF text detection")
use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.")
skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.")

Expand Down Expand Up @@ -233,5 +276,13 @@ def page_count(pdf_file):
st.image(table_img, caption="Table Recognition", use_container_width=True)
st.json([p.model_dump() for p in pred], expanded=True)

# Handle the "Run bad PDF text detection" button.
if ocr_errors:
    if "pdf" not in filetype:
        st.error("This feature only works with PDFs.")
    else:
        # page_count is only set for PDF uploads, so the check must not run
        # for other file types (it would fail on a None page count).
        label, results = run_ocr_errors(in_file, page_count)
        with col1:
            st.write(label)
            st.json(results)

# The uploaded page/image is always shown in the second column.
with col2:
    st.image(pil_image, caption="Uploaded Image", use_container_width=True)
Loading
Loading