diff --git a/OCR/ocr/api.py b/OCR/ocr/api.py index 444c834e..54ced5c7 100644 --- a/OCR/ocr/api.py +++ b/OCR/ocr/api.py @@ -4,8 +4,10 @@ import json import cv2 as cv import numpy as np +import asyncio -from fastapi import FastAPI, UploadFile, Form + +from fastapi import FastAPI, UploadFile, Form, HTTPException from fastapi.middleware.cors import CORSMiddleware from ocr.services.image_ocr import ImageOCR @@ -60,15 +62,46 @@ async def image_alignment(source_image: str = Form(), segmentation_template: str @app.post("/image_file_to_text/") async def image_file_to_text(source_image: UploadFile, segmentation_template: UploadFile, labels: str = Form()): - source_image_np = np.frombuffer(await source_image.read(), np.uint8) - source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR) - - segmentation_template_np = np.frombuffer(await segmentation_template.read(), np.uint8) - segmentation_template_img = cv.imdecode(segmentation_template_np, cv.IMREAD_COLOR) - - loaded_json = json.loads(labels) - segments = segmenter.segment(source_image_img, segmentation_template_img, loaded_json) - results = ocr.image_to_text(segments) + try: + source_image_np = np.frombuffer(await source_image.read(), np.uint8) + source_image_img = cv.imdecode(source_image_np, cv.IMREAD_COLOR) + + if source_image_img is None: + raise HTTPException( + status_code=422, detail="Failed to decode source image. Ensure the file is a valid image format." + ) + + segmentation_template_np = np.frombuffer(await segmentation_template.read(), np.uint8) + segmentation_template_img = cv.imdecode(segmentation_template_np, cv.IMREAD_COLOR) + + if segmentation_template_img is None: + raise HTTPException( + status_code=422, + detail="Failed to decode segmentation template. Ensure the file is a valid image format.", + ) + + if source_image_img.shape[:2] != segmentation_template_img.shape[:2]: + raise HTTPException( + status_code=400, + detail="Dimension mismatch between source image and segmentation template. 
Both images must have the same width and height.", + ) + + loaded_json = json.loads(labels) + + segments = segmenter.segment(source_image_img, segmentation_template_img, loaded_json) + results = ocr.image_to_text(segments) + + except json.JSONDecodeError: + raise HTTPException( + status_code=422, detail="Failed to parse labels JSON. Ensure the labels are in valid JSON format." + ) + except asyncio.TimeoutError: + raise HTTPException(status_code=504, detail="The request timed out. Please try again.") + except HTTPException as e: + raise e + except Exception as e: + print(f"Unexpected error occurred: {str(e)}") + raise HTTPException(status_code=500, detail="An unexpected server error occurred.") return results diff --git a/OCR/tests/api_test.py b/OCR/tests/api_test.py index cb10e39b..c386d5fb 100644 --- a/OCR/tests/api_test.py +++ b/OCR/tests/api_test.py @@ -1,6 +1,9 @@ import base64 import os import json +from unittest import mock +import asyncio + from fastapi.testclient import TestClient @@ -13,6 +16,8 @@ segmentation_template_path = os.path.join(path, "./assets/form_segmention_template.png") source_image_path = os.path.join(path, "./assets/form_filled.png") labels_path = os.path.join(path, "./assets/labels.json") +invalid_dimension_path = os.path.join(path, "./assets/invalid_dimension_template.png") +invalid_image_file_path = os.path.join(path, "./assets/invalid_image_file.png") class TestAPI: @@ -104,3 +109,112 @@ def test_image_to_text_with_padding(self): response_json = response.json() assert response_json["nbs_patient_id"][0] == "SIENNA HAMPTON" assert response_json["nbs_cas_id"][0] == "123555" + + def test_invalid_source_image_format(self): + with ( + open(segmentation_template_path, "rb") as segmentation_template_file, + open(invalid_image_file_path, "rb") as source_image_file, # using invalid image + open(labels_path, "r") as labels, + ): + label_data = json.load(labels) + files_to_send = [ + ("source_image", source_image_file), + ("segmentation_template", 
segmentation_template_file), + ] + + response = client.post( + url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)} + ) + + assert response.status_code == 422 + response_json = response.json() + assert ( + response_json["detail"] == "Failed to decode source image. Ensure the file is a valid image format." + ) + + def test_invalid_segmentation_template_format(self): + with ( + open(invalid_image_file_path, "rb") as segmentation_template_file, # using an invalid image + open(source_image_path, "rb") as source_image_file, + open(labels_path, "r") as labels, + ): + label_data = json.load(labels) + files_to_send = [ + ("source_image", source_image_file), + ("segmentation_template", segmentation_template_file), + ] + + response = client.post( + url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)} + ) + + assert response.status_code == 422 + assert ( + response.json()["detail"] + == "Failed to decode segmentation template. Ensure the file is a valid image format." + ) + + def test_dimension_mismatch(self): + with ( + open(source_image_path, "rb") as source_image_file, + open(invalid_dimension_path, "rb") as invalid_dimension_file, # using a file with different dimensions + open(labels_path, "r") as labels, + ): + label_data = json.load(labels) + files_to_send = [ + ("source_image", source_image_file), + ("segmentation_template", invalid_dimension_file), + ] + + response = client.post( + url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)} + ) + + assert response.status_code == 400 + assert ( + response.json()["detail"] + == "Dimension mismatch between source image and segmentation template. Both images must have the same width and height." 
+ ) + + def test_invalid_json_labels(self): + with ( + open(source_image_path, "rb") as source_image_file, + open(segmentation_template_path, "rb") as segmentation_template_file, + ): + invalid_label_data = "{invalid: json}" # string that is not valid JSON + files_to_send = [ + ("source_image", source_image_file), + ("segmentation_template", segmentation_template_file), + ] + + response = client.post( + url="/image_file_to_text", files=files_to_send, data={"labels": invalid_label_data} + ) + + assert response.status_code == 422 + assert ( + response.json()["detail"] + == "Failed to parse labels JSON. Ensure the labels are in valid JSON format." + ) + + def test_timeout_error_simulation(self): + with ( + open(source_image_path, "rb") as source_image_file, + open(segmentation_template_path, "rb") as segmentation_template_file, + open(labels_path, "r") as labels, + ): + label_data = json.load(labels) + files_to_send = [ + ("source_image", source_image_file), + ("segmentation_template", segmentation_template_file), + ] + + with mock.patch( + "ocr.services.image_segmenter.ImageSegmenter.segment", side_effect=asyncio.TimeoutError + ): # makes the patched segment call raise a timeout + response = client.post( + url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)} + ) + + assert response.status_code == 504 + assert response.json()["detail"] == "The request timed out. Please try again." diff --git a/OCR/tests/assets/invalid_dimension_template.png b/OCR/tests/assets/invalid_dimension_template.png new file mode 100644 index 00000000..983eb4ac Binary files /dev/null and b/OCR/tests/assets/invalid_dimension_template.png differ diff --git a/OCR/tests/assets/invalid_image_file.png b/OCR/tests/assets/invalid_image_file.png new file mode 100644 index 00000000..7a8aee90 Binary files /dev/null and b/OCR/tests/assets/invalid_image_file.png differ