Skip to content

Commit

Permalink
mypy typecheck wsi_registration
Browse files Browse the repository at this point in the history
  • Loading branch information
Jiaqi-Lv committed Mar 21, 2024
1 parent b297067 commit 52357da
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 46 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/mypy-type-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,4 @@ jobs:
tiatoolbox/typing.py \
tiatoolbox/tiatoolbox.py \
tiatoolbox/utils/*.py \
tiatoolbox/tools/__init__.py \
tiatoolbox/tools/stainextract.py \
tiatoolbox/tools/pyramid.py \
tiatoolbox/tools/tissuemask.py \
tiatoolbox/tools/graph.py \
tiatoolbox/tools/stainnorm.py \
tiatoolbox/tools/stainaugment.py \
tiatoolbox/tools/patchextraction.py
tiatoolbox/tools/*.py
89 changes: 51 additions & 38 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

import cv2
import numpy as np
Expand Down Expand Up @@ -176,8 +176,8 @@ def prealignment(
if len(moving_mask.shape) != BIN_MASK_DIM:
moving_mask = moving_mask[:, :, 0]

fixed_mask = np.uint8(fixed_mask > 0)
moving_mask = np.uint8(moving_mask > 0)
fixed_mask = (fixed_mask > 0).astype(np.uint8)
moving_mask = (moving_mask > 0).astype(np.uint8)

fixed_img = np.squeeze(fixed_img)
moving_img = np.squeeze(moving_img)
Expand All @@ -195,8 +195,8 @@ def prealignment(
moving_img = exposure.rescale_intensity(img_as_float(moving_img), in_range=(0, 1))

# Resizing of fixed and moving masks so that dice can be computed
height = np.max((fixed_mask.shape[0], moving_mask.shape[0]))
width = np.max((fixed_mask.shape[1], moving_mask.shape[1]))
height = int(np.max((fixed_mask.shape[0], moving_mask.shape[0])))
width = int(np.max((fixed_mask.shape[1], moving_mask.shape[1])))
padded_fixed_mask = np.pad(
fixed_mask,
pad_width=[(0, height - fixed_mask.shape[0]), (0, width - fixed_mask.shape[1])],
Expand Down Expand Up @@ -345,15 +345,16 @@ def __init__(self: torch.nn.Module) -> None:
for i, layer in enumerate(output_layers_id)
]

def forward_hook(self: torch.nn.Module, layer_name: str) -> None:
def forward_hook(self: torch.nn.Module, layer_name: str) -> Callable:
"""Register a hook.
Args:
layer_name (str):
User-defined name for a layer.
Returns:
None
hook (Callable):
Forward hook for feature extraction.
"""

Expand Down Expand Up @@ -430,7 +431,8 @@ class DFBRegister:
def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
"""Initialize :class:`DFBRegister`."""
self.patch_size = patch_size
self.x_scale, self.y_scale = [], []
self.x_scale: list[float] = []
self.y_scale: list[float] = []
self.feature_extractor = DFBRFeatureExtractor()

# Make this function private when full pipeline is implemented.
Expand Down Expand Up @@ -505,7 +507,7 @@ def finding_match(feature_dist: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
first_min = feature_dist[seq, ind_first_min]
mask = np.zeros_like(feature_dist)
mask[seq, ind_first_min] = 1
masked = np.ma.masked_array(feature_dist, mask)
masked: np.ndarray = np.ma.masked_array(feature_dist, mask)
second_min = np.amin(masked, axis=1)
return np.array([seq, ind_first_min]).transpose(), np.array(
second_min / first_min,
Expand Down Expand Up @@ -631,7 +633,7 @@ def feature_mapping(
moving_points = (moving_points - 112.0) / 224.0

matching_points, quality = self.finding_match(feature_dist)
max_quality = np.max(quality)
max_quality: float = np.max(quality)
while np.where(quality >= max_quality)[0].shape[0] <= num_matching_points:
max_quality -= 0.01

Expand Down Expand Up @@ -697,7 +699,7 @@ def get_tissue_regions(
fixed_mask: np.ndarray,
moving_image: np.ndarray,
moving_mask: np.ndarray,
) -> tuple[np.array, np.array, np.array, np.array, IntBounds]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, IntBounds]:
"""Extract tissue region.
This function uses binary mask for extracting tissue
Expand Down Expand Up @@ -735,6 +737,10 @@ def get_tissue_regions(
moving_minc, moving_min_r, width, height = cv2.boundingRect(moving_mask)
moving_max_c, moving_max_r = moving_minc + width, moving_min_r + height

minc: int
max_c: int
min_r: int
max_r: int
minc, max_c, min_r, max_r = (
np.min([fixed_minc, moving_minc]),
np.max([fixed_max_c, moving_max_c]),
Expand Down Expand Up @@ -842,25 +848,24 @@ def filtering_matching_points(
)

# remove duplicate matching points
duplicate_ind = []
duplicate_ind: list[int] = []
unq, count = np.unique(fixed_matched_points, axis=0, return_counts=True)
repeated_points = unq[count > 1]
for repeated_point in repeated_points:
repeated_idx = np.argwhere(
np.all(fixed_matched_points == repeated_point, axis=1),
)
duplicate_ind = np.hstack([duplicate_ind, repeated_idx.ravel()])
duplicate_ind.extend(repeated_idx.ravel())

unq, count = np.unique(moving_matched_points, axis=0, return_counts=True)
repeated_points = unq[count > 1]
for repeated_point in repeated_points:
repeated_idx = np.argwhere(
np.all(moving_matched_points == repeated_point, axis=1),
)
duplicate_ind = np.hstack([duplicate_ind, repeated_idx.ravel()])
duplicate_ind.extend(repeated_idx.ravel())

if len(duplicate_ind) > 0:
duplicate_ind = duplicate_ind.astype(int)
fixed_matched_points = np.delete(
fixed_matched_points,
duplicate_ind,
Expand Down Expand Up @@ -1002,7 +1007,9 @@ def perform_dfbregister_block_wise(
right_lower_bounding_bbox,
]

fixed_matched_points, moving_matched_points, quality = [], [], []
fixed_matched_points: list[np.ndarray] = []
moving_matched_points: list[np.ndarray] = []
quality: list[np.ndarray] = []
for _index, bounding_box in enumerate(blocks_bounding_box):
fixed_block = fixed_img[
bounding_box[0] : bounding_box[1],
Expand All @@ -1028,26 +1035,26 @@ def perform_dfbregister_block_wise(
moving_block_matched_points + bounding_box_2_0,
)
quality.append(block_quality)
fixed_matched_points, moving_matched_points, quality = (
fixed_matched_points_arr, moving_matched_points_arr, quality_arr = (
np.concatenate(fixed_matched_points),
np.concatenate(moving_matched_points),
np.concatenate(quality),
)
(
fixed_matched_points,
moving_matched_points,
fixed_matched_points_arr,
moving_matched_points_arr,
_,
) = self.filtering_matching_points(
fixed_mask,
moving_mask,
fixed_matched_points,
moving_matched_points,
quality,
fixed_matched_points_arr,
moving_matched_points_arr,
quality_arr,
)

block_transform = DFBRegister.estimate_affine_transform(
fixed_matched_points,
moving_matched_points,
fixed_matched_points_arr,
moving_matched_points_arr,
)

# Apply transformation
Expand Down Expand Up @@ -1108,8 +1115,8 @@ def register(
if len(moving_mask.shape) > BIN_MASK_DIM:
moving_mask = moving_mask[:, :, 0]

fixed_mask = np.uint8(fixed_mask > 0)
moving_mask = np.uint8(moving_mask > 0)
fixed_mask = (fixed_mask > 0).astype(np.uint8)
moving_mask = (moving_mask > 0).astype(np.uint8)

# Perform Pre-alignment
if transform_initializer is None:
Expand Down Expand Up @@ -1531,9 +1538,16 @@ def get_patch_dimensions(
transform = transform * [[1, 1, 0], [1, 1, 0], [1, 1, 1]] # remove translation
transform_points = self.transform_points(points, transform)

width = np.max(transform_points[:, 0]) - np.min(transform_points[:, 0]) + 1
height = np.max(transform_points[:, 1]) - np.min(transform_points[:, 1]) + 1
width, height = np.ceil(width).astype(int), np.ceil(height).astype(int)
width = (
int(np.max(transform_points[:, 0]))
- int(np.min(transform_points[:, 0]))
+ 1
)
height = (
int(np.max(transform_points[:, 1]))
- int(np.min(transform_points[:, 1]))
+ 1
)

return (width, height)

Expand All @@ -1542,7 +1556,7 @@ def get_transformed_location(
location: tuple[int, int],
size: tuple[int, int],
level: int,
) -> tuple[int, int]:
) -> tuple[tuple[int, int], tuple[int, int]]:
"""Get corresponding location on unregistered image and the required patch size.
This function applies inverse transformation to the centre point of the region.
Expand Down Expand Up @@ -1571,16 +1585,13 @@ def get_transformed_location(
inv_transform = inv(self.transform_level0)
size_level0 = [x * (2**level) for x in size]
center_level0 = [x + size_level0[i] / 2 for i, x in enumerate(location)]
center_level0 = np.expand_dims(np.array(center_level0), axis=0)
center_level0 = self.transform_points(center_level0, inv_transform)[0]
center_level0_arr = np.expand_dims(np.array(center_level0), axis=0)
center_level0_arr = self.transform_points(center_level0_arr, inv_transform)[0]

transformed_size = self.get_patch_dimensions(size, inv_transform)
transformed_location = [
center_level0[0] - (transformed_size[0] * (2**level)) / 2,
center_level0[1] - (transformed_size[1] * (2**level)) / 2,
]
transformed_location = tuple(
np.round(x).astype(int) for x in transformed_location
transformed_location = (
int(center_level0_arr[0] - (transformed_size[0] * (2**level)) / 2),
int(center_level0_arr[1] - (transformed_size[1] * (2**level)) / 2),
)
return transformed_location, transformed_size

Expand Down Expand Up @@ -1653,6 +1664,8 @@ def read_rect(
resolution=resolution,
units=units,
)
transformed_location: tuple[int, int]
max_size: tuple[int, int]
transformed_location, max_size = self.get_transformed_location(
location,
level_size,
Expand Down

0 comments on commit 52357da

Please sign in to comment.