Skip to content

Commit

Permalink
Add image alignment and registration code (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonchang authored Jun 13, 2024
1 parent 01fcdfb commit ae4f6f0
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 1 deletion.
3 changes: 3 additions & 0 deletions OCR/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .four_point_transform import FourPointTransform as FourPointTransform
from .image_homography import ImageHomography as ImageHomography
from .random_perspective_transform import RandomPerspectiveTransform as RandomPerspectiveTransform
47 changes: 47 additions & 0 deletions OCR/alignment/four_point_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""
Uses quadrilateral edge detection and executes a four-point perspective transform on a source image.
"""

from pathlib import Path
import functools

import numpy as np
import cv2 as cv


class FourPointTransform:
    """Dewarp an image by mapping its largest four-sided contour onto the full frame.

    The image is loaded as grayscale. The largest external contour is simplified to a
    polygon, and that polygon's four vertices are used as the source quadrilateral of a
    perspective transform onto the image's own width/height rectangle.
    """

    def __init__(self, image: Path):
        """Load `image` from disk as a single-channel grayscale array.

        Raises:
            ValueError: if OpenCV cannot read or decode the file.
        """
        self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE)
        # cv.imread reports failure by returning None rather than raising.
        if self.image is None:
            raise ValueError(f"Could not read image: {image}")

    @staticmethod
    def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
        """Return the four vertices ordered top-left, top-right, bottom-right, bottom-left.

        `quadrilateral` is any array that can be reshaped to 4x2 (x, y) pairs. The result
        is a 4x2 float32 array.
        """
        quadrilateral = quadrilateral.reshape(4, 2)
        output_quad = np.zeros([4, 2]).astype(np.float32)

        # x + y is smallest at the top-left corner and largest at the bottom-right.
        s = quadrilateral.sum(axis=1)
        output_quad[0] = quadrilateral[np.argmin(s)]
        output_quad[2] = quadrilateral[np.argmax(s)]

        # y - x is smallest at the top-right corner and largest at the bottom-left.
        diff = np.diff(quadrilateral, axis=1)
        output_quad[1] = quadrilateral[np.argmin(diff)]
        output_quad[3] = quadrilateral[np.argmax(diff)]
        return output_quad

    def find_largest_contour(self):
        """Compute contours for an image and find the biggest one by area.

        Raises:
            ValueError: if no contours are found.
        """
        # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
        # (contours, hierarchy); the contours are second-from-last in both.
        contours = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[-2]
        if len(contours) == 0:
            raise ValueError("No contours found in image")
        return max(contours, key=cv.contourArea)

    def simplify_polygon(self, contour):
        """Simplify to a polygon with (hopefully four) vertices."""
        perimeter = cv.arcLength(contour, True)
        # Approximation tolerance is 1% of the closed contour's perimeter.
        return cv.approxPolyDP(contour, 0.01 * perimeter, True)

    def dewarp(self) -> np.ndarray:
        """Warp the largest quadrilateral in the image onto the full image rectangle.

        Returns:
            The perspective-corrected image, with the same width and height as the input.

        Raises:
            ValueError: if the simplified largest contour does not have exactly four vertices.
        """
        biggest_contour = self.find_largest_contour()
        simplified = self.simplify_polygon(biggest_contour)
        if len(simplified) != 4:
            raise ValueError(f"Expected a quadrilateral with 4 vertices, found {len(simplified)}")

        height, width = self.image.shape
        destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)

        M = cv.getPerspectiveTransform(self._order_points(simplified), destination)
        return cv.warpPerspective(self.image, M, (width, height))
50 changes: 50 additions & 0 deletions OCR/alignment/image_homography.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from pathlib import Path

import numpy as np
import cv2 as cv


class ImageHomography:
    """Align a query image to a template image using SIFT features and a RANSAC homography."""

    # cv.findHomography needs at least four point correspondences.
    _MIN_MATCH_COUNT = 4

    def __init__(self, template: Path, match_ratio=0.3):
        """Initialize the image homography pipeline with a `template` image.

        Args:
            template: path of the template image to align against.
            match_ratio: threshold for Lowe's ratio test, strictly between 0 and 1.

        Raises:
            ValueError: if `match_ratio` is out of range or the template cannot be read.
        """
        if match_ratio >= 1 or match_ratio <= 0:
            raise ValueError("`match_ratio` must be between 0 and 1")

        # Convert to str so Path objects also work on OpenCV builds that only accept strings.
        self.template = cv.imread(str(template))
        # cv.imread reports failure by returning None rather than raising.
        if self.template is None:
            raise ValueError(f"Could not read template image: {template}")
        self.match_ratio = match_ratio
        self._sift = cv.SIFT_create()

    def estimate_self_similarity(self):
        """Calibrate `match_ratio` using a self-similarity metric."""
        raise NotImplementedError

    def compute_descriptors(self, img):
        """Compute SIFT keypoints and descriptors for a target `img`."""
        return self._sift.detectAndCompute(img, None)

    def knn_match(self, descriptor_template, descriptor_query):
        """Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image."""
        matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
        return matcher.knnMatch(descriptor_template, descriptor_query, 2)

    def transform_homography(self, other):
        """Run the image homography pipeline against a query image.

        Args:
            other: the query image array to align to the template.

        Returns:
            `other` warped into the template's frame, with the template's width and height.

        Raises:
            ValueError: if either image yields no descriptors, too few matches pass the
                ratio test, or no homography can be estimated.
        """
        # find the keypoints and descriptors with SIFT
        kp1, descriptors1 = self.compute_descriptors(self.template)
        kp2, descriptors2 = self.compute_descriptors(other)
        if descriptors1 is None or descriptors2 is None:
            raise ValueError("No SIFT features detected in the template or query image")

        knn_matches = self.knn_match(descriptors1, descriptors2)

        # Filter matches using the Lowe's ratio test
        # use an aggressive threshold here- the larger the image the more aggressively this should be filtered
        # Entries with fewer than two neighbours cannot be ratio-tested, so they are skipped.
        good_matches = [
            pair[0]
            for pair in knn_matches
            if len(pair) == 2 and pair[0].distance < self.match_ratio * pair[1].distance
        ]
        if len(good_matches) < self._MIN_MATCH_COUNT:
            raise ValueError(
                f"Not enough good matches to estimate a homography: "
                f"found {len(good_matches)}, need at least {self._MIN_MATCH_COUNT}"
            )

        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

        # Map query-image points onto template points so the query is warped into the template frame.
        M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)
        if M is None:
            raise ValueError("Could not estimate a homography from the matched keypoints")

        return cv.warpPerspective(other, M, (self.template.shape[1], self.template.shape[0]))
37 changes: 37 additions & 0 deletions OCR/alignment/random_perspective_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Perspective transforms a base image between 10% and 90% distortion.
"""

from pathlib import Path

import torchvision.transforms as transforms
from PIL import Image


class RandomPerspectiveTransform:
    """Produce randomly perspective-warped versions of a template `image`."""

    def __init__(self, image: Path):
        """Open the template `image` with Pillow and keep it for later warps."""
        self.image = Image.open(image)

    @staticmethod
    def _make_transform(distortion_scale: float) -> object:
        """
        Build the composed torchvision pipeline used for one perspective warp.

        A new pipeline is created on every call; per the original author's note, this
        is needed for the RandomPerspective step to be truly random between repeated
        calls to `transform`.
        """
        pipeline = [
            # p=1 means the perspective warp is always applied.
            transforms.RandomPerspective(distortion_scale=distortion_scale, p=1),
            transforms.ToTensor(),
            transforms.ToPILImage(),
        ]
        return transforms.Compose(pipeline)

    def transform(self, distortion_scale: float) -> object:
        """Return the template image warped with the given `distortion_scale`.

        Raises:
            ValueError: if `distortion_scale` is negative or not below 1.
        """
        if distortion_scale >= 1 or distortion_scale < 0:
            raise ValueError("`distortion_scale` must be between 0 and 1")

        warp = self._make_transform(distortion_scale)
        return warp(self.image)
39 changes: 38 additions & 1 deletion OCR/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions OCR/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ opencv-python = "^4.9.0.80"
python-dotenv = "^1.0.1"
transformers = {extras = ["torch"], version = "^4.39.3"}
pillow = "^10.3.0"
torchvision = "^0.18.0"

[tool.poetry.group.dev.dependencies]
ruff = "^0.3.7"
Expand Down
27 changes: 27 additions & 0 deletions OCR/tests/alignment_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import cv2 as cv
import numpy as np

from alignment import ImageHomography, RandomPerspectiveTransform


# Directory containing this test module; the fixture images live under ./assets.
path = os.path.dirname(__file__)

# Paths of the blank template form and a filled-in copy of the same form.
template_image_path, filled_image_path = (
    os.path.join(path, asset) for asset in ("./assets/template_hep.jpg", "./assets/form_filled_hep.jpg")
)

# The filled form as loaded by OpenCV, used as the comparison baseline in the tests.
filled_image = cv.imread(filled_image_path)


class TestAlignment:
    """Tests for the random perspective warp and the homography-based aligner."""

    def test_random_warp(self):
        """A warped copy of the filled form should differ from the original."""
        warper = RandomPerspectiveTransform(filled_image_path)
        transformed = warper.transform(distortion_scale=0.1)
        difference = cv.absdiff(np.array(transformed), filled_image)
        assert np.median(difference) > 0

    def test_alignment_filled(self):
        """Aligning a warped filled form to the template should give a median pixel difference of zero."""
        aligner = ImageHomography(template_image_path)
        warped_pil = RandomPerspectiveTransform(filled_image_path).transform(distortion_scale=0.1)
        warped_image = np.array(warped_pil)
        aligned = aligner.transform_homography(warped_image)
        res = cv.absdiff(aligner.template, aligned)
        assert aligner.template.shape == warped_image.shape
        assert np.median(res) == 0
Binary file added OCR/tests/assets/template_hep.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit ae4f6f0

Please sign in to comment.