Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tile masking for coarse offsets computation #58

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__pycache__
dist
sofima.egg-info
_version.py
72 changes: 66 additions & 6 deletions stitch_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
for every tile based on cross-correlation between tile overlaps.
"""

from typing import Mapping, Optional
from typing import Any, Tuple, Union, Mapping, Optional

import jax.numpy as jnp
import numpy as np
Expand All @@ -31,9 +31,17 @@
from sofima import flow_field
from sofima import mesh

TileXY = tuple[int, int]
MaskMap = Mapping[TileXY, np.ndarray]
Vector = Union[Tuple[int, int], Tuple[int, int, int], Union[Tuple[int], Tuple[Any, ...]]]


def _estimate_offset(
a: np.ndarray, b: np.ndarray, range_limit: float, filter_size: int = 10
a: np.ndarray,
b: np.ndarray,
range_limit: float,
filter_size: int = 10,
masks: Optional[Tuple[np.ndarray]] = None,
) -> tuple[list[float], float]:
"""Estimates the global offset vector between images 'a' and 'b'."""
# Mask areas with insufficient dynamic range.
Expand All @@ -45,6 +53,12 @@ def _estimate_offset(
ndimage.maximum_filter(b, filter_size)
- ndimage.minimum_filter(b, filter_size)
) < range_limit

# Apply custom overlap masks
if masks is not None:
a_mask |= masks[0]
b_mask |= masks[1]

mfc = flow_field.JAXMaskedXCorrWithStatsCalculator()
xo, yo, _, pr = mfc.flow_field(
a, b, pre_mask=a_mask, post_mask=b_mask, patch_size=a.shape, step=(1, 1)
Expand All @@ -58,9 +72,14 @@ def _estimate_offset_horiz(
right: np.ndarray,
range_limit: float,
filter_size: int,
masks: Optional[Tuple[np.ndarray]] = None,
) -> tuple[list[float], float]:
return _estimate_offset(
left[:, -overlap:], right[:, :overlap], range_limit, filter_size
a=left[:, -overlap:],
b=right[:, :overlap],
range_limit=range_limit,
filter_size=filter_size,
masks=masks
)


Expand All @@ -70,9 +89,14 @@ def _estimate_offset_vert(
bot: np.ndarray,
range_limit: float,
filter_size: int,
masks: Optional[Tuple[np.ndarray]],
) -> tuple[list[float], float]:
return _estimate_offset(
top[-overlap:, :], bot[:overlap, :], range_limit, filter_size
a=top[-overlap:, :],
b=bot[:overlap, :],
range_limit=range_limit,
masks=masks,
filter_size=filter_size
)


Expand All @@ -83,6 +107,7 @@ def compute_coarse_offsets(
min_range=(10, 100, 0),
min_overlap=160,
filter_size=10,
mask_map: Optional[MaskMap] = None
) -> tuple[np.ndarray, np.ndarray]:
"""Computes a coarse offset between every neighboring tile pair.

Expand All @@ -99,6 +124,9 @@ def compute_coarse_offsets(
min_overlap: minimum overlap required for the estimate to be considered
valid
filter_size: size of the filter to use when evaluating dynamic range
mask_map: map from (x, y) tile coordinates to boolean arrays (same shape as
tile images); If present, the elements of the mask evaluating to True
define the pixels that should be masked during coarse offsets estimation.

Returns:
two arrays of shape [2, 1] + yx_shape, where the dimensions are:
Expand All @@ -118,7 +146,7 @@ def compute_coarse_offsets(
tiles are set to nan.
"""

def _find_offset(estimate_fn, pre, post, overlaps, max_ortho_shift, axis):
def _find_offset(estimate_fn, pre, post, overlaps, max_ortho_shift, axis, masks=None):
def _is_valid_offset(offset, axis):
return (
abs(offset[1 - axis]) < max_ortho_shift
Expand All @@ -134,7 +162,19 @@ def _is_valid_offset(offset, axis):
max_pr = 0
estimates = []
for overlap in overlaps:
offset, pr = estimate_fn(overlap, pre, post, range_limit, filter_size)

# Mask overlaps if needed
ov_masks = None
if masks is not None:
ma = masks[0][:, -overlap:] if axis == 0 else masks[0][-overlap:, :]
mb = masks[1][:, :overlap] if axis == 0 else masks[1][:overlap, :]
# Disable ov masking if overlap region would be completely masked
ma = np.full_like(ma, fill_value=False) if np.all(ma) else ma
mb = np.full_like(mb, fill_value=False) if np.all(mb) else mb
ov_masks = (ma, mb)

offset, pr = estimate_fn(overlap, pre, post, range_limit, filter_size,
ov_masks)
offset[axis] -= overlap

# If a single peak is found, terminate search.
Expand Down Expand Up @@ -183,13 +223,23 @@ def _is_valid_offset(offset, axis):

left = tile_map[(x, y)]
right = tile_map[(x + 1, y)]

# Load and crop overlap masks
masks_x = None
if mask_map is not None:
ov_width = max(overlaps_xy[0])
ov_ma = mask_map[(x, y)][:, -ov_width:]
ov_mb = mask_map[(x + 1, y)][:, :ov_width]
masks_x = (ov_ma, ov_mb)

conn_x[:, 0, y, x] = _find_offset(
_estimate_offset_horiz,
left,
right,
overlaps_xy[0],
max(overlaps_xy[1]),
0,
masks_x
)

conn_y = np.full((2, 1, yx_shape[0], yx_shape[1]), np.nan)
Expand All @@ -200,13 +250,23 @@ def _is_valid_offset(offset, axis):

top = tile_map[(x, y)]
bot = tile_map[(x, y + 1)]

# Load and crop overlap masks
masks_y = None
if mask_map is not None:
ov_width = max(overlaps_xy[1])
ov_ma = mask_map[(x, y)][-ov_width:]
ov_mb = mask_map[(x, y + 1)][:ov_width]
masks_y = (ov_ma, ov_mb)

conn_y[:, 0, y, x] = _find_offset(
_estimate_offset_vert,
top,
bot,
overlaps_xy[1],
max(overlaps_xy[0]),
1,
masks_y
)

return conn_x, conn_y
Expand Down