Add image alignment and registration code (#69)
Showing 8 changed files with 203 additions and 1 deletion.
@@ -0,0 +1,3 @@
from .four_point_transform import FourPointTransform as FourPointTransform
from .image_homography import ImageHomography as ImageHomography
from .random_perspective_transform import RandomPerspectiveTransform as RandomPerspectiveTransform
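The package `__init__` above re-exports all three classes, so callers (like the tests later in this commit) can import them straight from the `alignment` package, e.g.:

from alignment import FourPointTransform, ImageHomography, RandomPerspectiveTransform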
@@ -0,0 +1,47 @@
"""
Uses quadrilateral edge detection and executes a four-point perspective transform on a source image.
"""

from pathlib import Path
import functools

import numpy as np
import cv2 as cv


class FourPointTransform:
    def __init__(self, image: Path):
        self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE)

    @staticmethod
    def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
        """Reorder a 4x2 array of quadrilateral vertices so they appear in top-left, top-right, bottom-right, bottom-left order."""
        quadrilateral = quadrilateral.reshape(4, 2)
        output_quad = np.zeros([4, 2]).astype(np.float32)
        # Top-left has the smallest x + y sum, bottom-right the largest.
        s = quadrilateral.sum(axis=1)
        output_quad[0] = quadrilateral[np.argmin(s)]
        output_quad[2] = quadrilateral[np.argmax(s)]
        # Top-right has the smallest y - x difference, bottom-left the largest.
        diff = np.diff(quadrilateral, axis=1)
        output_quad[1] = quadrilateral[np.argmin(diff)]
        output_quad[3] = quadrilateral[np.argmax(diff)]
        return output_quad

    def find_largest_contour(self):
        """Compute contours for an image and find the biggest one by area."""
        # OpenCV >= 4 returns (contours, hierarchy).
        contours, _ = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
        return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)

    def simplify_polygon(self, contour):
        """Simplify to a polygon with (hopefully four) vertices."""
        perimeter = cv.arcLength(contour, True)
        return cv.approxPolyDP(contour, 0.01 * perimeter, True)

    def dewarp(self) -> np.ndarray:
        biggest_contour = self.find_largest_contour()
        simplified = self.simplify_polygon(biggest_contour)

        height, width = self.image.shape
        destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)

        M = cv.getPerspectiveTransform(self._order_points(simplified), destination)
        return cv.warpPerspective(self.image, M, (width, height))
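A minimal usage sketch for FourPointTransform, assuming a hypothetical scanned-form path. Since cv.findContours treats any nonzero pixel as foreground, a thresholded scan whose page boundary forms the largest contour will dewarp most reliably:

# Usage sketch; "scanned_form.jpg" is a hypothetical input path.
import cv2 as cv
from alignment import FourPointTransform

dewarped = FourPointTransform("scanned_form.jpg").dewarp()  # grayscale numpy array
cv.imwrite("dewarped_form.jpg", dewarped)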
@@ -0,0 +1,50 @@
from pathlib import Path

import numpy as np
import cv2 as cv


class ImageHomography:
    def __init__(self, template: Path, match_ratio=0.3):
        """Initialize the image homography pipeline with a `template` image."""
        if match_ratio >= 1 or match_ratio <= 0:
            raise ValueError("`match_ratio` must be between 0 and 1")

        self.template = cv.imread(str(template))
        self.match_ratio = match_ratio
        self._sift = cv.SIFT_create()

    def estimate_self_similarity(self):
        """Calibrate `match_ratio` using a self-similarity metric."""
        raise NotImplementedError

    def compute_descriptors(self, img):
        """Compute SIFT keypoints and descriptors for a target `img`."""
        return self._sift.detectAndCompute(img, None)

    def knn_match(self, descriptor_template, descriptor_query):
        """Return k-nearest-neighbor matches (k=2) between descriptors from the template and a query image."""
        matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
        return matcher.knnMatch(descriptor_template, descriptor_query, 2)

    def transform_homography(self, other):
        """Run the image homography pipeline against a query image."""
        # Find the keypoints and descriptors with SIFT.
        kp1, descriptors1 = self.compute_descriptors(self.template)
        kp2, descriptors2 = self.compute_descriptors(other)

        knn_matches = self.knn_match(descriptors1, descriptors2)

        # Filter matches using Lowe's ratio test. Use an aggressive threshold here:
        # the larger the image, the more aggressively this should be filtered.
        good_matches = []
        for m, n in knn_matches:
            if m.distance < self.match_ratio * n.distance:
                good_matches.append(m)

        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

        # Estimate the homography that maps query points onto template points.
        M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)

        return cv.warpPerspective(other, M, (self.template.shape[1], self.template.shape[0]))
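A usage sketch for the registration pipeline above, with hypothetical file names standing in for a blank form template and a photographed, filled-in copy:

import cv2 as cv
from alignment import ImageHomography

# Hypothetical paths: the blank template and a photo of a filled copy.
aligner = ImageHomography("template_form.jpg")
query = cv.imread("filled_form_photo.jpg")

# Warp the query into the template's coordinate frame; the output matches the
# template's height and width, so the two can be compared pixel-for-pixel.
registered = aligner.transform_homography(query)
cv.imwrite("registered_form.jpg", registered)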
@@ -0,0 +1,37 @@
"""
Applies a random perspective transform to a base image at a configurable distortion scale (e.g. 10% to 90%).
"""

from pathlib import Path

import torchvision.transforms as transforms
from PIL import Image


class RandomPerspectiveTransform:
    """Generate a random perspective transform based on a template `image`."""

    def __init__(self, image: Path):
        self.image = Image.open(image)

    @staticmethod
    def _make_transform(distortion_scale: float) -> object:
        """
        Internal function to create a composed transformer for perspective warps.
        This must be instantiated anew on each call so that the RandomPerspective transformer is truly random across repeated calls to `transform`.
        """
        return transforms.Compose(
            [
                transforms.RandomPerspective(distortion_scale=distortion_scale, p=1),
                transforms.ToTensor(),
                transforms.ToPILImage(),
            ]
        )

    def transform(self, distortion_scale: float) -> object:
        """Warp the template image with the specified `distortion_scale`."""
        if distortion_scale < 0 or distortion_scale >= 1:
            raise ValueError("`distortion_scale` must be between 0 and 1")

        return self._make_transform(distortion_scale)(self.image)
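A sketch of generating a synthetic warp from a template, again with a hypothetical path; the result is a PIL Image because the composed transform ends with ToPILImage():

from alignment import RandomPerspectiveTransform

# Hypothetical path; each call to transform() samples a fresh random warp.
warper = RandomPerspectiveTransform("template_form.jpg")
warped = warper.transform(distortion_scale=0.4)
warped.save("warped_form.png")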
@@ -0,0 +1,27 @@
import os

import cv2 as cv
import numpy as np

from alignment import ImageHomography, RandomPerspectiveTransform


path = os.path.dirname(__file__)

template_image_path = os.path.join(path, "./assets/template_hep.jpg")
filled_image_path = os.path.join(path, "./assets/form_filled_hep.jpg")
filled_image = cv.imread(filled_image_path)


class TestAlignment:
    def test_random_warp(self):
        transformed = RandomPerspectiveTransform(filled_image_path).transform(distortion_scale=0.1)
        assert np.median(cv.absdiff(np.array(transformed), filled_image)) > 0

    def test_alignment_filled(self):
        aligner = ImageHomography(template_image_path)
        warped_image = np.array(RandomPerspectiveTransform(filled_image_path).transform(distortion_scale=0.1))
        aligned = aligner.transform_homography(warped_image)
        res = cv.absdiff(aligner.template, aligned)
        assert aligner.template.shape == warped_image.shape
        assert np.median(res) == 0
(The remaining changed files in this commit are not rendered in this view.)