diff --git a/OCR/README.md b/OCR/README.md index 7d35a493..301d2c0b 100644 --- a/OCR/README.md +++ b/OCR/README.md @@ -84,35 +84,38 @@ poetry run api ### Test Data Sets -You can also run the script pytest run reportvision-dataset-1/medical_report_import.py to pull in all relevant data. - - -### Run E2E Benchmark Main -This will: -1.Segment and Run OCR on a folder of images using given segmentation template and labels file. -2.Compare OCR outputs to ground truth by searching for matching file names . -3.Writes metrics(Confidence,Raw Distance,Hamming Distance, Levenshtein Distance) as well as total metrics to a csv file. - - -To Run: -Locate file benchmark_main.py -Ensure all the paths/folders exist -https://drive.google.com/drive/folders/1WS2FYn0BTxWv0juh7lblzdMaFlI7zbDd?usp=sharing (This link for all segmentation/labels files) -Ensure ground_truth folder and files exist -Ensure labels.json is in the correct format see(tax_form_segmented_labels.json as an example) -When running make sure to pass arguments in this order: - -/path/to/image/folder (path to the original image files which we need to run ocr on) -/path/to/segmentation_template.png(single_file) -/path/to/labels.json(single file) -/path/to/output/folder (path to folder where the output would be. This should exist but can be empty) -/path/to/ground/truth_folder(path to folder for metrics that we would compare against) -/path/to/csv_out_folder(path to folder where all metrics would be. This should exist but can be empty) -the last arguement is a number 1 for running segmentation and ocr 2 for metrics analysis and 3 for running both - -Notes: -benchmark takes one second per segment for OCR please be patient or set a counter to limit the number of files processed -Only one segment can be inputted at a time +You can also run the script `pytest run reportvision-dataset-1/medical_report_import.py` to pull in all relevant data. + + +### Run end-to-end benchmarking + +End-to-end benchmarking scripts can: + +1. 
Segment and run OCR on a folder of images using a given segmentation template and labels file. +2. Compare OCR outputs to ground truth data based on matching file names. +3. Write metrics (confidence, raw distance, Hamming distance, Levenshtein distance) as well as total metrics to a CSV file. + + +To run benchmarking: + +1. Locate file `benchmark_main.py` +2. Ensure all the paths/folders exist by downloading from [Google Drive for all segmentation/label files](https://drive.google.com/drive/folders/1WS2FYn0BTxWv0juh7lblzdMaFlI7zbDd?usp=sharing) +3. Ensure `ground_truth` folder and files exist +4. Ensure `labels.json` is in the correct format (see `tax_form_segmented_labels.json` as an example) +5. When running, make sure to pass arguments in this order: + +* `/path/to/image/folder` (path to the original image files which we need to run OCR on) +* `/path/to/segmentation_template.png` (single file) +* `/path/to/labels.json` (single file) +* `/path/to/output/folder` (path to folder where the output will be written. This should exist but can be empty) +* `/path/to/ground/truth_folder` (path to folder with the ground truth JSON files that we would compare against) +* `/path/to/csv_out_folder` (path to folder where all metrics will be saved. This should exist but can be empty) + +By default, segmentation, OCR, and metrics computation are all run together. To disable one or the other, pass the `--no-ocr` or `--no-metrics` flag. You can change the backend model by passing `--model=...` as well. + +Run notes: +* Benchmark takes one second per segment for OCR using the default `trocr` model. Please be patient or set a counter to limit the number of files processed. 
+* Only one segment can be input at a time ### Dockerized Development diff --git a/OCR/tests/benchmark_main.py b/OCR/benchmark_main.py similarity index 52% rename from OCR/tests/benchmark_main.py rename to OCR/benchmark_main.py index 919e81b8..203d151f 100644 --- a/OCR/tests/benchmark_main.py +++ b/OCR/benchmark_main.py @@ -1,10 +1,16 @@ import argparse -from tests.batch_segmentation import BatchSegmentationOCR -from tests.batch_metrics import BatchMetricsAnalysis +from ocr.services.batch_segmentation import BatchSegmentationOCR +from ocr.services.batch_metrics import BatchMetricsAnalysis -def main(): - parser = argparse.ArgumentParser(description="Run OCR and metrics analysis.") +from ocr.services.tesseract_ocr import TesseractOCR, PSM +from ocr.services.image_ocr import ImageOCR + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run OCR and metrics analysis.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) parser.add_argument("image_folder", help="Path to the folder containing image files.") parser.add_argument("segmentation_template", help="Path to the segmentation template.") parser.add_argument("labels_path", help="Path to the labels file (JSON).") @@ -12,26 +18,31 @@ def main(): parser.add_argument("ground_truth_folder", help="Path to the folder with ground truth JSON files.") parser.add_argument("csv_output_folder", help="Path to the folder where CSV metrics will be saved.") parser.add_argument( - "run_type", - type=int, - choices=[1, 2, 3], - help="Choose run type: 1 for Segmentation Only, 2 for Metrics Only, 3 for Both.", + "--ocr", + action=argparse.BooleanOptionalAction, + default=True, + help="Run (or don't run) segmentation and OCR analysis", + ) + parser.add_argument( + "--metrics", action=argparse.BooleanOptionalAction, default=True, help="Run (or don't run) metrics analysis" + ) + parser.add_argument( + "--model", choices=["tesseract", "trocr"], default="trocr", help="OCR model to run for `--ocr` option." 
) - args = parser.parse_args() + return parser.parse_args() + +def main(): + args = parse_args() ocr_results = None - if args.run_type == 1: # Segmentation Only - print("Running segmentation and OCR...") + if args.ocr: + print(f"Running segmentation and OCR using {args.model}...") ocr_results = run_segmentation_and_ocr(args) - elif args.run_type == 2: # Metrics Only + if args.metrics: print("Running metrics analysis...") - run_metrics_analysis(args, ocr_results=None) - elif args.run_type == 3: - print("Running both segmentation,ocr and metrics analysis...") - ocr_results = run_segmentation_and_ocr(args) - run_metrics_analysis(args, ocr_results) + run_metrics_analysis(args, ocr_results=ocr_results) def run_segmentation_and_ocr(args): @@ -39,8 +50,19 @@ def run_segmentation_and_ocr(args): Runs segmentation and OCR processing. Returns OCR results with processing time. """ + + model = None + + if args.model == "tesseract": + # We are doing segmentation (not tesseract) so: + # * Disable border rejection of text too close to the edge of the image + # * Enforce single-line mode for tesseract + model = TesseractOCR(psm=PSM.SINGLE_LINE, variables=dict(tessedit_image_border="0")) + elif args.model == "trocr": + model = ImageOCR() + segmentation_ocr = BatchSegmentationOCR( - args.image_folder, args.segmentation_template, args.labels_path, args.output_folder + args.image_folder, args.segmentation_template, args.labels_path, args.output_folder, model=model ) ocr_results = segmentation_ocr.process_images() print(f"OCR results saved to: {args.output_folder}") diff --git a/OCR/tests/batch_metrics.py b/OCR/ocr/services/batch_metrics.py similarity index 100% rename from OCR/tests/batch_metrics.py rename to OCR/ocr/services/batch_metrics.py diff --git a/OCR/tests/batch_segmentation.py b/OCR/ocr/services/batch_segmentation.py similarity index 95% rename from OCR/tests/batch_segmentation.py rename to OCR/ocr/services/batch_segmentation.py index f16dc153..d2d80f34 100644 --- 
a/OCR/tests/batch_segmentation.py +++ b/OCR/ocr/services/batch_segmentation.py @@ -2,16 +2,20 @@ import json import time import csv + from ocr.services.image_segmenter import ImageSegmenter from ocr.services.image_ocr import ImageOCR class BatchSegmentationOCR: - def __init__(self, image_folder, segmentation_template, labels_path, output_folder): + def __init__(self, image_folder, segmentation_template, labels_path, output_folder, model=None): self.image_folder = image_folder self.segmentation_template = segmentation_template self.labels_path = labels_path self.output_folder = output_folder + self.model = model + if self.model is None: + self.model = ImageOCR() os.makedirs(self.output_folder, exist_ok=True) def process_images(self): @@ -19,7 +23,7 @@ def process_images(self): Processes all images and returns OCR results with time taken. """ segmenter = ImageSegmenter() - ocr = ImageOCR() + ocr = self.model results = [] time_dict = {} diff --git a/OCR/ocr/services/image_segmenter.py b/OCR/ocr/services/image_segmenter.py index 115ab0b2..6378727b 100644 --- a/OCR/ocr/services/image_segmenter.py +++ b/OCR/ocr/services/image_segmenter.py @@ -31,7 +31,7 @@ def segment_by_mask_then_crop(raw_image, segmentation_template, labels, debug) - segmentation_template = np.array(segmentation_template, copy=True) color = tuple(map(int, reversed(color.split(",")))) # create a mask for that color - mask = np.all(segmentation_template == color, axis=2).astype(int) + mask = np.all(segmentation_template == color, axis=2).astype(np.uint8) # add a third dimension to the mask mask = mask[:, :, np.newaxis] @@ -100,11 +100,11 @@ def load_and_segment(self, raw_image_path, segmentation_template_path, labels_pa ): raise FileNotFoundError("One or more input files do not exist.") - raw_image = cv.imread(raw_image_path) + raw_image = cv.imread(raw_image_path, cv.IMREAD_COLOR) if raw_image is None: raise ValueError(f"Failed to open image file: {raw_image_path}") - segmentation_template = 
cv.imread(segmentation_template_path) + segmentation_template = cv.imread(segmentation_template_path, cv.IMREAD_COLOR) if segmentation_template is None: raise ValueError(f"Failed to open image file: {segmentation_template_path}") diff --git a/OCR/ocr/services/metrics_analysis.py b/OCR/ocr/services/metrics_analysis.py index cea663e2..a6c7109c 100644 --- a/OCR/ocr/services/metrics_analysis.py +++ b/OCR/ocr/services/metrics_analysis.py @@ -48,11 +48,11 @@ def raw_distance(ocr_text, ground_truth): def hamming_distance(ocr_text, ground_truth): if len(ocr_text) != len(ground_truth): raise ValueError("Strings must be of the same length to calculate Hamming distance.") - return Levenshtein.hamming(ocr_text, ground_truth) + return Levenshtein.hamming(ocr_text.upper(), ground_truth.upper()) @staticmethod def levenshtein_distance(ocr_text, ground_truth): - return Levenshtein.distance(ocr_text, ground_truth) + return Levenshtein.distance(ocr_text.upper(), ground_truth.upper()) def extract_values_from_json(self, json_data): if json_data is None: diff --git a/OCR/ocr/services/tesseract_ocr.py b/OCR/ocr/services/tesseract_ocr.py index 294d80f1..9ce4a78e 100644 --- a/OCR/ocr/services/tesseract_ocr.py +++ b/OCR/ocr/services/tesseract_ocr.py @@ -1,11 +1,22 @@ import os import tesserocr +from tesserocr import PSM import numpy as np from PIL import Image class TesseractOCR: + def __init__(self, psm=PSM.AUTO, variables=dict()): + """ + Initialize the tesseract OCR model. + + `psm` (int): an enum (from `PSM`) that defines tesseract's page segmentation mode. Default is `AUTO`. 
+ `variables` (dict): a dict to customize tesseract's behavior with internal variables + """ + self.psm = psm + self.variables = variables + @staticmethod def _guess_tessdata_path(wanted_lang="eng") -> bytes: """ @@ -52,7 +63,7 @@ def _guess_tessdata_path(wanted_lang="eng") -> bytes: def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, tuple[str, float]]: digitized: dict[str, tuple[str, float]] = {} - with tesserocr.PyTessBaseAPI(path=self._guess_tessdata_path()) as api: + with tesserocr.PyTessBaseAPI(psm=self.psm, variables=self.variables, path=self._guess_tessdata_path()) as api: for label, image in segments.items(): if image is None: continue diff --git a/OCR/tests/benchmark_test.py b/OCR/tests/benchmark_test.py index 5ba86473..e076b73c 100644 --- a/OCR/tests/benchmark_test.py +++ b/OCR/tests/benchmark_test.py @@ -7,6 +7,7 @@ from Levenshtein import distance, ratio from ocr.services.image_ocr import ImageOCR +from ocr.services.tesseract_ocr import TesseractOCR from PIL import Image, ImageDraw, ImageFont path = os.path.dirname(__file__) @@ -95,6 +96,7 @@ def generate_exact_segments( class TestBenchmark: ocr = ImageOCR() + tess = TesseractOCR() sample_size = 10 test_cases = [ @@ -115,9 +117,10 @@ class TestBenchmark: @pytest.mark.benchmark(group="OCR Model Performance", min_rounds=1) @pytest.mark.parametrize("name,segments", test_cases) - def test_ocr_english_sentences(self, name, segments, benchmark): + @pytest.mark.parametrize("model", (ocr, tess)) + def test_ocr_english_sentences(self, name, segments, model, benchmark): print("\n", name) - results = benchmark(self.ocr.image_to_text, segments) + results = benchmark(model.image_to_text, segments) actual_labels = [x.lower() for x in list(results.keys())] predicted_labels = [x[0].lower() for x in list(results.values())]