diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml new file mode 100644 index 00000000..9f08df6c --- /dev/null +++ b/.github/workflows/lint.yaml @@ -0,0 +1,21 @@ +name: Lint checks +on: + workflow_call: + workflow_dispatch: + pull_request: + branches: + - '**' + push: + branches: + - main + +jobs: + python: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - name: Ruff checks + run: pipx run ruff check --output-format=github + - name: Ruff format + run: pipx run ruff format --check --diff diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 00000000..72699e49 --- /dev/null +++ b/.ruff.toml @@ -0,0 +1 @@ +exclude = ['dedupe/*'] diff --git a/OCR/ocr/azure_form_ex.py b/OCR/ocr/azure_form_ex.py index 8925e2ab..43ed9ed6 100644 --- a/OCR/ocr/azure_form_ex.py +++ b/OCR/ocr/azure_form_ex.py @@ -8,16 +8,14 @@ load_dotenv() -endpoint = os.getenv('FORM_RECOGNIZER_ENDPOINT') -key = os.getenv('FORM_RECOGNIZER_KEY') -file_path = os.getenv('FORM_RECOGNIZER_FILE_PATH') - +endpoint = os.getenv("FORM_RECOGNIZER_ENDPOINT") +key = os.getenv("FORM_RECOGNIZER_KEY") +file_path = os.getenv("FORM_RECOGNIZER_FILE_PATH") client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(key)) - try: with open(file_path, "rb") as f: poller = client.begin_analyze_document("prebuilt-document", document=f) @@ -25,19 +23,18 @@ except Exception as e: print(f"An error occurred: {e}") else: - with open('output.json', 'w') as json_file: + with open("output.json", "w") as json_file: for idx, style in enumerate(result.styles): if style.is_handwritten: print(f"Document contains handwritten content: {style.confidence}") for page in result.pages: print(f"Page number: {page.page_number} has {len(page.lines)} lines and {len(page.words)} words.") - lines = page.lines[:] + lines = page.lines[:] lines_content = [line.content for line in lines] kv_pairs = {} for line in lines_content: - if ':' in line: - key, value = line.split(':', 1) + if ":" in line: + key, value = line.split(":", 1) kv_pairs[key.strip()] = value.strip() json.dump(kv_pairs, json_file, indent=4) - diff --git a/OCR/ocr/main.py b/OCR/ocr/main.py index bfde2fff..0d6598cd 100644 --- a/OCR/ocr/main.py +++ b/OCR/ocr/main.py @@ -7,10 +7,7 @@ def main(): - - segmentation_template = os.path.join( - path, "../tests/assets/form_segmention_template.png" - ) + segmentation_template = os.path.join(path, "../tests/assets/form_segmention_template.png") raw_image = os.path.join(path, "../tests/assets/form_filled.png") labels_path = os.path.join(path, "../tests/assets/labels.json") diff --git a/OCR/ocr/services/image_ocr.py b/OCR/ocr/services/image_ocr.py index b9be7fe0..b11701a5 100644 --- a/OCR/ocr/services/image_ocr.py +++ b/OCR/ocr/services/image_ocr.py @@ -3,7 +3,6 @@ class ImageOCR: - def __init__(self, model="microsoft/trocr-base-printed"): self.processor = TrOCRProcessor.from_pretrained(model) self.model = VisionEncoderDecoderModel.from_pretrained(model) @@ -11,15 +10,10 @@ def __init__(self, model="microsoft/trocr-base-printed"): def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, str]: digitized: dict[str, str] = {} for label, image in segments.items(): - - pixel_values = self.processor( - images=image, return_tensors="pt" - ).pixel_values + pixel_values = self.processor(images=image, return_tensors="pt").pixel_values generated_ids = self.model.generate(pixel_values) - generated_text = self.processor.batch_decode( - generated_ids, skip_special_tokens=True - ) + generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True) digitized[label] = generated_text[0] return digitized diff --git a/OCR/ocr/services/image_segmenter.py b/OCR/ocr/services/image_segmenter.py index 3e7ccb16..b5b66cc4 100644 --- a/OCR/ocr/services/image_segmenter.py +++ b/OCR/ocr/services/image_segmenter.py @@ -6,7 +6,6 @@ class ImageSegmenter: def __init__(self, raw_image, segmentation_template, labels): - if not os.path.isfile(raw_image) or not os.path.isfile(segmentation_template): raise FileNotFoundError("One or more input files do not exist.") @@ -31,16 +30,12 @@ def segment(self) -> dict[str, np.ndarray]: indices = np.where(np.all(self.segmentation_template == color, axis=-1)) # if there are no matching pixels if indices[0].size == 0: - raise ValueError( - f"No pixels found for color {color} in segmentation template." - ) + raise ValueError(f"No pixels found for color {color} in segmentation template.") # if there are matching pixels if indices[0].size > 0: # Find the x-y coordinates y_min, y_max = indices[0].min(), indices[0].max() x_min, x_max = indices[1].min(), indices[1].max() # crop the area and store the image in the dict - self.segments[label] = self.raw_image[ - y_min : y_max + 1, x_min : x_max + 1 - ] + self.segments[label] = self.raw_image[y_min : y_max + 1, x_min : x_max + 1] return self.segments diff --git a/OCR/pyproject.toml b/OCR/pyproject.toml index 65ad21a8..b18ab13f 100644 --- a/OCR/pyproject.toml +++ b/OCR/pyproject.toml @@ -23,3 +23,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] main = "ocr.main:main" + +[tool.ruff] +line-length = 118 +target-version = "py310" diff --git a/OCR/tests/ocr_test.py b/OCR/tests/ocr_test.py index e27a5b17..bc6edb95 100644 --- a/OCR/tests/ocr_test.py +++ b/OCR/tests/ocr_test.py @@ -12,7 +12,6 @@ class TestOCR: - def test_ocr_printed(self): segmenter = ImageSegmenter(raw_image, segmentation_template, labels_path) ocr = ImageOCR() @@ -23,9 +22,7 @@ def test_ocr_printed(self): assert results["nbs_cas_id"] == "SIENNA HAMPTON" def test_ocr_handwritten(self): - segmenter = ImageSegmenter( - raw_image_handwritten, segmentation_template, labels_path - ) + segmenter = ImageSegmenter(raw_image_handwritten, segmentation_template, labels_path) ocr = ImageOCR(model="microsoft/trocr-base-handwritten") results = ocr.image_to_text(segmenter.segment()) diff --git a/OCR/tests/segmentation_template_test.py b/OCR/tests/segmentation_template_test.py index 0cd4f4dd..dfea8c59 100644 --- a/OCR/tests/segmentation_template_test.py +++ b/OCR/tests/segmentation_template_test.py @@ -19,9 +19,7 @@ def setup(self): self.raw_image = raw_image self.segmentation_template = segmentation_template self.labels_path = labels_path - self.segmenter = ImageSegmenter( - self.raw_image, self.segmentation_template, self.labels_path - ) + self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path) def test_segment(self): segments = self.segmenter.segment() @@ -32,11 +30,6 @@ def test_segment(self): for segment in segments.values(): assert isinstance(segment, np.ndarray) - def test_segment_shapes(self): - segments = self.segmenter.segment() - for segment in segments.values(): - assert len(segment.shape) == 3 - def test_segment_shapes(self): expected_shapes = {"nbs_patient_id": (41, 376, 3), "nbs_cas_id": (57, 366, 3)} segments = self.segmenter.segment() @@ -46,9 +39,7 @@ def test_segment_shapes(self): def test_no_matching_pixels(self): segmentation_template = np.zeros((10, 10, 3), dtype=np.uint8) cv.imwrite("no_matching_colors.png", segmentation_template) - segmenter = ImageSegmenter( - self.raw_image, "no_matching_colors.png", self.labels_path - ) + segmenter = ImageSegmenter(self.raw_image, "no_matching_colors.png", self.labels_path) with pytest.raises(ValueError): segmenter.segment() os.remove("no_matching_colors.png")