Skip to content

Commit

Permalink
Add mypy Type Check to wsi_registration.py (#801)
Browse files Browse the repository at this point in the history
- Add `mypy` Type Check to `wsi_registration.py`
  - This adds `mypy` checks to all modules in `tools`.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shan E Ahmed Raza <[email protected]>
  • Loading branch information
3 people authored Mar 22, 2024
1 parent 4332fab commit 7b6f1ee
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 47 deletions.
11 changes: 2 additions & 9 deletions .github/workflows/mypy-type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,5 @@ jobs:
tiatoolbox/__main__.py \
tiatoolbox/typing.py \
tiatoolbox/tiatoolbox.py \
tiatoolbox/utils/*.py \
tiatoolbox/tools/__init__.py \
tiatoolbox/tools/stainextract.py \
tiatoolbox/tools/pyramid.py \
tiatoolbox/tools/tissuemask.py \
tiatoolbox/tools/graph.py \
tiatoolbox/tools/stainnorm.py \
tiatoolbox/tools/stainaugment.py \
tiatoolbox/tools/patchextraction.py
tiatoolbox/utils/ \
tiatoolbox/tools/
87 changes: 49 additions & 38 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
Expand Down Expand Up @@ -176,8 +176,8 @@ def prealignment(
if len(moving_mask.shape) != BIN_MASK_DIM:
moving_mask = moving_mask[:, :, 0]

fixed_mask = np.uint8(fixed_mask > 0)
moving_mask = np.uint8(moving_mask > 0)
fixed_mask = (fixed_mask > 0).astype(np.uint8)
moving_mask = (moving_mask > 0).astype(np.uint8)

fixed_img = np.squeeze(fixed_img)
moving_img = np.squeeze(moving_img)
Expand All @@ -195,8 +195,8 @@ def prealignment(
moving_img = exposure.rescale_intensity(img_as_float(moving_img), in_range=(0, 1))

# Resizing of fixed and moving masks so that dice can be computed
height = np.max((fixed_mask.shape[0], moving_mask.shape[0]))
width = np.max((fixed_mask.shape[1], moving_mask.shape[1]))
height = int(np.max((fixed_mask.shape[0], moving_mask.shape[0])))
width = int(np.max((fixed_mask.shape[1], moving_mask.shape[1])))
padded_fixed_mask = np.pad(
fixed_mask,
pad_width=[(0, height - fixed_mask.shape[0]), (0, width - fixed_mask.shape[1])],
Expand Down Expand Up @@ -348,15 +348,16 @@ def __init__(self: torch.nn.Module) -> None:
for i, layer in enumerate(output_layers_id)
]

def forward_hook(self: torch.nn.Module, layer_name: str) -> None:
def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable:
"""Register a hook.
Args:
layer_name (str):
User-defined name for a layer.
Returns:
None
hook (Callable):
Forward hook for feature extraction.
"""

Expand Down Expand Up @@ -433,7 +434,8 @@ class DFBRegister:
def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
"""Initialize :class:`DFBRegister`."""
self.patch_size = patch_size
self.x_scale, self.y_scale = [], []
self.x_scale: list[float] = []
self.y_scale: list[float] = []
self.feature_extractor = DFBRFeatureExtractor()

# Make this function private when full pipeline is implemented.
Expand Down Expand Up @@ -508,7 +510,7 @@ def finding_match(feature_dist: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
first_min = feature_dist[seq, ind_first_min]
mask = np.zeros_like(feature_dist)
mask[seq, ind_first_min] = 1
masked = np.ma.masked_array(feature_dist, mask)
masked: np.ma.MaskedArray = np.ma.masked_array(feature_dist, mask)
second_min = np.amin(masked, axis=1)
return np.array([seq, ind_first_min]).transpose(), np.array(
second_min / first_min,
Expand Down Expand Up @@ -634,7 +636,7 @@ def feature_mapping(
moving_points = (moving_points - 112.0) / 224.0

matching_points, quality = self.finding_match(feature_dist)
max_quality = np.max(quality)
max_quality: float = np.max(quality)
while np.where(quality >= max_quality)[0].shape[0] <= num_matching_points:
max_quality -= 0.01

Expand Down Expand Up @@ -700,7 +702,7 @@ def get_tissue_regions(
fixed_mask: np.ndarray,
moving_image: np.ndarray,
moving_mask: np.ndarray,
) -> tuple[np.array, np.array, np.array, np.array, IntBounds]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, IntBounds]:
"""Extract tissue region.
This function uses binary mask for extracting tissue
Expand Down Expand Up @@ -738,6 +740,10 @@ def get_tissue_regions(
moving_minc, moving_min_r, width, height = cv2.boundingRect(moving_mask)
moving_max_c, moving_max_r = moving_minc + width, moving_min_r + height

minc: int
max_c: int
min_r: int
max_r: int
minc, max_c, min_r, max_r = (
np.min([fixed_minc, moving_minc]),
np.max([fixed_max_c, moving_max_c]),
Expand Down Expand Up @@ -845,25 +851,24 @@ def filtering_matching_points(
)

# remove duplicate matching points
duplicate_ind = []
duplicate_ind: list[int] = []
unq, count = np.unique(fixed_matched_points, axis=0, return_counts=True)
repeated_points = unq[count > 1]
for repeated_point in repeated_points:
repeated_idx = np.argwhere(
np.all(fixed_matched_points == repeated_point, axis=1),
)
duplicate_ind = np.hstack([duplicate_ind, repeated_idx.ravel()])
duplicate_ind.extend(repeated_idx.ravel())

unq, count = np.unique(moving_matched_points, axis=0, return_counts=True)
repeated_points = unq[count > 1]
for repeated_point in repeated_points:
repeated_idx = np.argwhere(
np.all(moving_matched_points == repeated_point, axis=1),
)
duplicate_ind = np.hstack([duplicate_ind, repeated_idx.ravel()])
duplicate_ind.extend(repeated_idx.ravel())

if len(duplicate_ind) > 0:
duplicate_ind = duplicate_ind.astype(int)
fixed_matched_points = np.delete(
fixed_matched_points,
duplicate_ind,
Expand Down Expand Up @@ -1005,7 +1010,9 @@ def perform_dfbregister_block_wise(
right_lower_bounding_bbox,
]

fixed_matched_points, moving_matched_points, quality = [], [], []
fixed_matched_points: list[np.ndarray] = []
moving_matched_points: list[np.ndarray] = []
quality: list[np.ndarray] = []
for _index, bounding_box in enumerate(blocks_bounding_box):
fixed_block = fixed_img[
bounding_box[0] : bounding_box[1],
Expand All @@ -1031,26 +1038,26 @@ def perform_dfbregister_block_wise(
moving_block_matched_points + bounding_box_2_0,
)
quality.append(block_quality)
fixed_matched_points, moving_matched_points, quality = (
fixed_matched_points_arr, moving_matched_points_arr, quality_arr = (
np.concatenate(fixed_matched_points),
np.concatenate(moving_matched_points),
np.concatenate(quality),
)
(
fixed_matched_points,
moving_matched_points,
fixed_matched_points_arr,
moving_matched_points_arr,
_,
) = self.filtering_matching_points(
fixed_mask,
moving_mask,
fixed_matched_points,
moving_matched_points,
quality,
fixed_matched_points_arr,
moving_matched_points_arr,
quality_arr,
)

block_transform = DFBRegister.estimate_affine_transform(
fixed_matched_points,
moving_matched_points,
fixed_matched_points_arr,
moving_matched_points_arr,
)

# Apply transformation
Expand Down Expand Up @@ -1111,8 +1118,8 @@ def register(
if len(moving_mask.shape) > BIN_MASK_DIM:
moving_mask = moving_mask[:, :, 0]

fixed_mask = np.uint8(fixed_mask > 0)
moving_mask = np.uint8(moving_mask > 0)
fixed_mask = (fixed_mask > 0).astype(np.uint8)
moving_mask = (moving_mask > 0).astype(np.uint8)

# Perform Pre-alignment
if transform_initializer is None:
Expand Down Expand Up @@ -1534,9 +1541,16 @@ def get_patch_dimensions(
transform = transform * [[1, 1, 0], [1, 1, 0], [1, 1, 1]] # remove translation
transform_points = self.transform_points(points, transform)

width = np.max(transform_points[:, 0]) - np.min(transform_points[:, 0]) + 1
height = np.max(transform_points[:, 1]) - np.min(transform_points[:, 1]) + 1
width, height = np.ceil(width).astype(int), np.ceil(height).astype(int)
width = (
int(np.max(transform_points[:, 0]))
- int(np.min(transform_points[:, 0]))
+ 1
)
height = (
int(np.max(transform_points[:, 1]))
- int(np.min(transform_points[:, 1]))
+ 1
)

return (width, height)

Expand All @@ -1545,7 +1559,7 @@ def get_transformed_location(
location: tuple[int, int],
size: tuple[int, int],
level: int,
) -> tuple[int, int]:
) -> tuple[tuple[int, int], tuple[int, int]]:
"""Get corresponding location on unregistered image and the required patch size.
This function applies inverse transformation to the centre point of the region.
Expand Down Expand Up @@ -1574,16 +1588,13 @@ def get_transformed_location(
inv_transform = inv(self.transform_level0)
size_level0 = [x * (2**level) for x in size]
center_level0 = [x + size_level0[i] / 2 for i, x in enumerate(location)]
center_level0 = np.expand_dims(np.array(center_level0), axis=0)
center_level0 = self.transform_points(center_level0, inv_transform)[0]
center_level0_arr = np.expand_dims(np.array(center_level0), axis=0)
center_level0_arr = self.transform_points(center_level0_arr, inv_transform)[0]

transformed_size = self.get_patch_dimensions(size, inv_transform)
transformed_location = [
center_level0[0] - (transformed_size[0] * (2**level)) / 2,
center_level0[1] - (transformed_size[1] * (2**level)) / 2,
]
transformed_location = tuple(
np.round(x).astype(int) for x in transformed_location
transformed_location = (
int(center_level0_arr[0] - (transformed_size[0] * (2**level)) / 2),
int(center_level0_arr[1] - (transformed_size[1] * (2**level)) / 2),
)
return transformed_location, transformed_size

Expand Down

0 comments on commit 7b6f1ee

Please sign in to comment.