Skip to content

Commit

Permalink
Add ruff lint and style checks (#76)
Browse files Browse the repository at this point in the history
* Add ruff lint and style checks

* Fix `ruff check`

* Fix `ruff format`

* Only enable for ocr work

* Revert changes to dedupe folder

We aren't linting or formatting dedupe work.
  • Loading branch information
jonchang authored Apr 17, 2024
1 parent 4411a5a commit c90d1e9
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 44 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Lint checks
on:
  workflow_call:
  workflow_dispatch:
  pull_request:
    branches:
      - '**'
  push:
    branches:
      - main

jobs:
  python:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - name: Ruff checks
        run: pipx run ruff check --output-format=github
      - name: Ruff format
        run: pipx run ruff format --check --diff
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
exclude = ['dedupe/*']
17 changes: 7 additions & 10 deletions OCR/ocr/azure_form_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,33 @@
load_dotenv()


endpoint = os.getenv('FORM_RECOGNIZER_ENDPOINT')
key = os.getenv('FORM_RECOGNIZER_KEY')
file_path = os.getenv('FORM_RECOGNIZER_FILE_PATH')

endpoint = os.getenv("FORM_RECOGNIZER_ENDPOINT")
key = os.getenv("FORM_RECOGNIZER_KEY")
file_path = os.getenv("FORM_RECOGNIZER_FILE_PATH")


client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(key))



try:
with open(file_path, "rb") as f:
poller = client.begin_analyze_document("prebuilt-document", document=f)
result = poller.result()
except Exception as e:
print(f"An error occurred: {e}")
else:
with open('output.json', 'w') as json_file:
with open("output.json", "w") as json_file:
for idx, style in enumerate(result.styles):
if style.is_handwritten:
print(f"Document contains handwritten content: {style.confidence}")

for page in result.pages:
print(f"Page number: {page.page_number} has {len(page.lines)} lines and {len(page.words)} words.")
lines = page.lines[:]
lines = page.lines[:]
lines_content = [line.content for line in lines]
kv_pairs = {}
for line in lines_content:
if ':' in line:
key, value = line.split(':', 1)
if ":" in line:
key, value = line.split(":", 1)
kv_pairs[key.strip()] = value.strip()
json.dump(kv_pairs, json_file, indent=4)

5 changes: 1 addition & 4 deletions OCR/ocr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@


def main():

segmentation_template = os.path.join(
path, "../tests/assets/form_segmention_template.png"
)
segmentation_template = os.path.join(path, "../tests/assets/form_segmention_template.png")
raw_image = os.path.join(path, "../tests/assets/form_filled.png")
labels_path = os.path.join(path, "../tests/assets/labels.json")

Expand Down
10 changes: 2 additions & 8 deletions OCR/ocr/services/image_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,17 @@


class ImageOCR:

def __init__(self, model="microsoft/trocr-base-printed"):
self.processor = TrOCRProcessor.from_pretrained(model)
self.model = VisionEncoderDecoderModel.from_pretrained(model)

def image_to_text(self, segments: dict[str, np.ndarray]) -> dict[str, str]:
digitized: dict[str, str] = {}
for label, image in segments.items():

pixel_values = self.processor(
images=image, return_tensors="pt"
).pixel_values
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

generated_ids = self.model.generate(pixel_values)
generated_text = self.processor.batch_decode(
generated_ids, skip_special_tokens=True
)
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
digitized[label] = generated_text[0]

return digitized
9 changes: 2 additions & 7 deletions OCR/ocr/services/image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

class ImageSegmenter:
def __init__(self, raw_image, segmentation_template, labels):

if not os.path.isfile(raw_image) or not os.path.isfile(segmentation_template):
raise FileNotFoundError("One or more input files do not exist.")

Expand All @@ -31,16 +30,12 @@ def segment(self) -> dict[str, np.ndarray]:
indices = np.where(np.all(self.segmentation_template == color, axis=-1))
# if there are no matching pixels
if indices[0].size == 0:
raise ValueError(
f"No pixels found for color {color} in segmentation template."
)
raise ValueError(f"No pixels found for color {color} in segmentation template.")
# if there are matching pixels
if indices[0].size > 0:
# Find the x-y coordinates
y_min, y_max = indices[0].min(), indices[0].max()
x_min, x_max = indices[1].min(), indices[1].max()
# crop the area and store the image in the dict
self.segments[label] = self.raw_image[
y_min : y_max + 1, x_min : x_max + 1
]
self.segments[label] = self.raw_image[y_min : y_max + 1, x_min : x_max + 1]
return self.segments
4 changes: 4 additions & 0 deletions OCR/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry.scripts]
main = "ocr.main:main"

[tool.ruff]
line-length = 118
target-version = "py310"
5 changes: 1 addition & 4 deletions OCR/tests/ocr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class TestOCR:

def test_ocr_printed(self):
segmenter = ImageSegmenter(raw_image, segmentation_template, labels_path)
ocr = ImageOCR()
Expand All @@ -23,9 +22,7 @@ def test_ocr_printed(self):
assert results["nbs_cas_id"] == "SIENNA HAMPTON"

def test_ocr_handwritten(self):
segmenter = ImageSegmenter(
raw_image_handwritten, segmentation_template, labels_path
)
segmenter = ImageSegmenter(raw_image_handwritten, segmentation_template, labels_path)
ocr = ImageOCR(model="microsoft/trocr-base-handwritten")

results = ocr.image_to_text(segmenter.segment())
Expand Down
13 changes: 2 additions & 11 deletions OCR/tests/segmentation_template_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def setup(self):
self.raw_image = raw_image
self.segmentation_template = segmentation_template
self.labels_path = labels_path
self.segmenter = ImageSegmenter(
self.raw_image, self.segmentation_template, self.labels_path
)
self.segmenter = ImageSegmenter(self.raw_image, self.segmentation_template, self.labels_path)

def test_segment(self):
segments = self.segmenter.segment()
Expand All @@ -32,11 +30,6 @@ def test_segment(self):
for segment in segments.values():
assert isinstance(segment, np.ndarray)

def test_segment_shapes(self):
segments = self.segmenter.segment()
for segment in segments.values():
assert len(segment.shape) == 3

def test_segment_shapes(self):
expected_shapes = {"nbs_patient_id": (41, 376, 3), "nbs_cas_id": (57, 366, 3)}
segments = self.segmenter.segment()
Expand All @@ -46,9 +39,7 @@ def test_segment_shapes(self):
def test_no_matching_pixels(self):
segmentation_template = np.zeros((10, 10, 3), dtype=np.uint8)
cv.imwrite("no_matching_colors.png", segmentation_template)
segmenter = ImageSegmenter(
self.raw_image, "no_matching_colors.png", self.labels_path
)
segmenter = ImageSegmenter(self.raw_image, "no_matching_colors.png", self.labels_path)
with pytest.raises(ValueError):
segmenter.segment()
os.remove("no_matching_colors.png")
Expand Down

0 comments on commit c90d1e9

Please sign in to comment.